1use async_trait::async_trait;
73use axum::extract::ws::Message;
74use futures_util::{SinkExt, StreamExt};
75use regex::Regex;
76use serde_json::Value;
77use std::collections::{HashMap, HashSet};
78use std::sync::Arc;
79use tokio::sync::{broadcast, Mutex, RwLock};
80
/// Convenience alias for results whose error type is [`HandlerError`].
pub type HandlerResult<T> = Result<T, HandlerError>;
83
/// Errors reported by WebSocket handlers and the helper types in this module.
///
/// Every variant except [`HandlerError::JsonError`] carries a human-readable description
/// that is appended to the variant's fixed message prefix.
#[derive(Debug, thiserror::Error)]
pub enum HandlerError {
    /// A message could not be delivered.
    #[error("Failed to send message: {0}")]
    SendError(String),

    /// JSON (de)serialization failed; converts automatically from `serde_json::Error`.
    #[error("Failed to parse JSON: {0}")]
    JsonError(#[from] serde_json::Error),

    /// A message pattern could not be built (for example an invalid regular expression).
    #[error("Pattern matching error: {0}")]
    PatternError(String),

    /// A room-related operation failed.
    #[error("Room operation failed: {0}")]
    RoomError(String),

    /// A connection could not be established or used.
    #[error("Connection error: {0}")]
    ConnectionError(String),

    /// Any other handler failure.
    #[error("Handler error: {0}")]
    Generic(String),
}
111
112#[derive(Debug, Clone)]
114pub enum WsMessage {
115 Text(String),
117 Binary(Vec<u8>),
119 Ping(Vec<u8>),
121 Pong(Vec<u8>),
123 Close,
125}
126
127impl From<Message> for WsMessage {
128 fn from(msg: Message) -> Self {
129 match msg {
130 Message::Text(text) => WsMessage::Text(text.to_string()),
131 Message::Binary(data) => WsMessage::Binary(data.to_vec()),
132 Message::Ping(data) => WsMessage::Ping(data.to_vec()),
133 Message::Pong(data) => WsMessage::Pong(data.to_vec()),
134 Message::Close(_) => WsMessage::Close,
135 }
136 }
137}
138
139impl From<WsMessage> for Message {
140 fn from(msg: WsMessage) -> Self {
141 match msg {
142 WsMessage::Text(text) => Message::Text(text.into()),
143 WsMessage::Binary(data) => Message::Binary(data.into()),
144 WsMessage::Ping(data) => Message::Ping(data.into()),
145 WsMessage::Pong(data) => Message::Pong(data.into()),
146 WsMessage::Close => Message::Close(None),
147 }
148 }
149}
150
151#[derive(Debug, Clone)]
153pub enum MessagePattern {
154 Regex(Regex),
156 JsonPath(String),
158 Exact(String),
160 Any,
162}
163
164impl MessagePattern {
165 pub fn regex(pattern: &str) -> HandlerResult<Self> {
167 Ok(MessagePattern::Regex(
168 Regex::new(pattern).map_err(|e| HandlerError::PatternError(e.to_string()))?,
169 ))
170 }
171
172 pub fn jsonpath(query: &str) -> Self {
174 MessagePattern::JsonPath(query.to_string())
175 }
176
177 pub fn exact(text: &str) -> Self {
179 MessagePattern::Exact(text.to_string())
180 }
181
182 pub fn any() -> Self {
184 MessagePattern::Any
185 }
186
187 pub fn matches(&self, text: &str) -> bool {
189 match self {
190 MessagePattern::Regex(re) => re.is_match(text),
191 MessagePattern::JsonPath(query) => {
192 if let Ok(json) = serde_json::from_str::<Value>(text) {
194 if let Ok(selector) = jsonpath::Selector::new(query) {
196 let results: Vec<_> = selector.find(&json).collect();
197 !results.is_empty()
198 } else {
199 false
200 }
201 } else {
202 false
203 }
204 }
205 MessagePattern::Exact(expected) => text == expected,
206 MessagePattern::Any => true,
207 }
208 }
209
210 pub fn extract(&self, text: &str, query: &str) -> Option<Value> {
212 if let Ok(json) = serde_json::from_str::<Value>(text) {
213 if let Ok(selector) = jsonpath::Selector::new(query) {
214 let results: Vec<_> = selector.find(&json).collect();
215 results.first().cloned().cloned()
216 } else {
217 None
218 }
219 } else {
220 None
221 }
222 }
223}
224
/// Identifier of a single WebSocket connection.
pub type ConnectionId = String;
227
228#[derive(Clone)]
230pub struct RoomManager {
231 rooms: Arc<RwLock<HashMap<String, HashSet<ConnectionId>>>>,
232 connections: Arc<RwLock<HashMap<ConnectionId, HashSet<String>>>>,
233 broadcasters: Arc<RwLock<HashMap<String, broadcast::Sender<String>>>>,
234}
235
236impl RoomManager {
237 pub fn new() -> Self {
239 Self {
240 rooms: Arc::new(RwLock::new(HashMap::new())),
241 connections: Arc::new(RwLock::new(HashMap::new())),
242 broadcasters: Arc::new(RwLock::new(HashMap::new())),
243 }
244 }
245
246 pub async fn join(&self, conn_id: &str, room: &str) -> HandlerResult<()> {
248 let mut rooms = self.rooms.write().await;
249 let mut connections = self.connections.write().await;
250
251 rooms
252 .entry(room.to_string())
253 .or_insert_with(HashSet::new)
254 .insert(conn_id.to_string());
255
256 connections
257 .entry(conn_id.to_string())
258 .or_insert_with(HashSet::new)
259 .insert(room.to_string());
260
261 Ok(())
262 }
263
264 pub async fn leave(&self, conn_id: &str, room: &str) -> HandlerResult<()> {
266 let mut rooms = self.rooms.write().await;
267 let mut connections = self.connections.write().await;
268
269 if let Some(room_members) = rooms.get_mut(room) {
270 room_members.remove(conn_id);
271 if room_members.is_empty() {
272 rooms.remove(room);
273 }
274 }
275
276 if let Some(conn_rooms) = connections.get_mut(conn_id) {
277 conn_rooms.remove(room);
278 if conn_rooms.is_empty() {
279 connections.remove(conn_id);
280 }
281 }
282
283 Ok(())
284 }
285
286 pub async fn leave_all(&self, conn_id: &str) -> HandlerResult<()> {
288 let mut connections = self.connections.write().await;
289 if let Some(conn_rooms) = connections.remove(conn_id) {
290 let mut rooms = self.rooms.write().await;
291 for room in conn_rooms {
292 if let Some(room_members) = rooms.get_mut(&room) {
293 room_members.remove(conn_id);
294 if room_members.is_empty() {
295 rooms.remove(&room);
296 }
297 }
298 }
299 }
300 Ok(())
301 }
302
303 pub async fn get_room_members(&self, room: &str) -> Vec<ConnectionId> {
305 let rooms = self.rooms.read().await;
306 rooms
307 .get(room)
308 .map(|members| members.iter().cloned().collect())
309 .unwrap_or_default()
310 }
311
312 pub async fn get_connection_rooms(&self, conn_id: &str) -> Vec<String> {
314 let connections = self.connections.read().await;
315 connections
316 .get(conn_id)
317 .map(|rooms| rooms.iter().cloned().collect())
318 .unwrap_or_default()
319 }
320
321 pub async fn get_broadcaster(&self, room: &str) -> broadcast::Sender<String> {
323 let mut broadcasters = self.broadcasters.write().await;
324 broadcasters
325 .entry(room.to_string())
326 .or_insert_with(|| {
327 let (tx, _) = broadcast::channel(1024);
328 tx
329 })
330 .clone()
331 }
332}
333
334impl Default for RoomManager {
335 fn default() -> Self {
336 Self::new()
337 }
338}
339
340pub struct WsContext {
342 pub connection_id: ConnectionId,
344 pub path: String,
346 room_manager: RoomManager,
348 message_tx: tokio::sync::mpsc::UnboundedSender<Message>,
350 metadata: Arc<RwLock<HashMap<String, Value>>>,
352}
353
354impl WsContext {
355 pub fn new(
357 connection_id: ConnectionId,
358 path: String,
359 room_manager: RoomManager,
360 message_tx: tokio::sync::mpsc::UnboundedSender<Message>,
361 ) -> Self {
362 Self {
363 connection_id,
364 path,
365 room_manager,
366 message_tx,
367 metadata: Arc::new(RwLock::new(HashMap::new())),
368 }
369 }
370
371 pub async fn send_text(&self, text: &str) -> HandlerResult<()> {
373 self.message_tx
374 .send(Message::Text(text.to_string().into()))
375 .map_err(|e| HandlerError::SendError(e.to_string()))
376 }
377
378 pub async fn send_binary(&self, data: Vec<u8>) -> HandlerResult<()> {
380 self.message_tx
381 .send(Message::Binary(data.into()))
382 .map_err(|e| HandlerError::SendError(e.to_string()))
383 }
384
385 pub async fn send_json(&self, value: &Value) -> HandlerResult<()> {
387 let text = serde_json::to_string(value)?;
388 self.send_text(&text).await
389 }
390
391 pub async fn join_room(&self, room: &str) -> HandlerResult<()> {
393 self.room_manager.join(&self.connection_id, room).await
394 }
395
396 pub async fn leave_room(&self, room: &str) -> HandlerResult<()> {
398 self.room_manager.leave(&self.connection_id, room).await
399 }
400
401 pub async fn broadcast_to_room(&self, room: &str, text: &str) -> HandlerResult<()> {
403 let broadcaster = self.room_manager.get_broadcaster(room).await;
404 broadcaster
405 .send(text.to_string())
406 .map_err(|e| HandlerError::RoomError(e.to_string()))?;
407 Ok(())
408 }
409
410 pub async fn get_rooms(&self) -> Vec<String> {
412 self.room_manager.get_connection_rooms(&self.connection_id).await
413 }
414
415 pub async fn set_metadata(&self, key: &str, value: Value) {
417 let mut metadata = self.metadata.write().await;
418 metadata.insert(key.to_string(), value);
419 }
420
421 pub async fn get_metadata(&self, key: &str) -> Option<Value> {
423 let metadata = self.metadata.read().await;
424 metadata.get(key).cloned()
425 }
426}
427
/// Callbacks invoked over the lifetime of a WebSocket connection.
///
/// Only [`WsHandler::on_message`] must be implemented; the other methods have defaults.
#[async_trait]
pub trait WsHandler: Send + Sync {
    /// Connection-established callback. The default does nothing and succeeds.
    async fn on_connect(&self, _ctx: &mut WsContext) -> HandlerResult<()> {
        Ok(())
    }

    /// Message callback, given the connection context and the received message.
    async fn on_message(&self, ctx: &mut WsContext, msg: WsMessage) -> HandlerResult<()>;

    /// Connection-closed callback. The default does nothing and succeeds.
    async fn on_disconnect(&self, _ctx: &mut WsContext) -> HandlerResult<()> {
        Ok(())
    }

    /// Returns whether this handler applies to connections on `_path`.
    /// The default accepts every path.
    fn handles_path(&self, _path: &str) -> bool {
        true
    }
}
449
450type MessageHandler = Box<dyn Fn(String) -> Option<String> + Send + Sync>;
452
453pub struct MessageRouter {
455 routes: Vec<(MessagePattern, MessageHandler)>,
456}
457
458impl MessageRouter {
459 pub fn new() -> Self {
461 Self { routes: Vec::new() }
462 }
463
464 pub fn on<F>(&mut self, pattern: MessagePattern, handler: F) -> &mut Self
466 where
467 F: Fn(String) -> Option<String> + Send + Sync + 'static,
468 {
469 self.routes.push((pattern, Box::new(handler)));
470 self
471 }
472
473 pub fn route(&self, text: &str) -> Option<String> {
475 for (pattern, handler) in &self.routes {
476 if pattern.matches(text) {
477 if let Some(response) = handler(text.to_string()) {
478 return Some(response);
479 }
480 }
481 }
482 None
483 }
484}
485
486impl Default for MessageRouter {
487 fn default() -> Self {
488 Self::new()
489 }
490}
491
492pub struct HandlerRegistry {
494 handlers: Vec<Arc<dyn WsHandler>>,
495 hot_reload_enabled: bool,
496}
497
498impl HandlerRegistry {
499 pub fn new() -> Self {
501 Self {
502 handlers: Vec::new(),
503 hot_reload_enabled: std::env::var("MOCKFORGE_WS_HOTRELOAD")
504 .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
505 .unwrap_or(false),
506 }
507 }
508
509 pub fn with_hot_reload() -> Self {
511 Self {
512 handlers: Vec::new(),
513 hot_reload_enabled: true,
514 }
515 }
516
517 pub fn is_hot_reload_enabled(&self) -> bool {
519 self.hot_reload_enabled
520 }
521
522 pub fn register<H: WsHandler + 'static>(&mut self, handler: H) -> &mut Self {
524 self.handlers.push(Arc::new(handler));
525 self
526 }
527
528 pub fn get_handlers(&self, path: &str) -> Vec<Arc<dyn WsHandler>> {
530 self.handlers.iter().filter(|h| h.handles_path(path)).cloned().collect()
531 }
532
533 pub fn has_handler_for(&self, path: &str) -> bool {
535 self.handlers.iter().any(|h| h.handles_path(path))
536 }
537
538 pub fn clear(&mut self) {
540 self.handlers.clear();
541 }
542
543 pub fn len(&self) -> usize {
545 self.handlers.len()
546 }
547
548 pub fn is_empty(&self) -> bool {
550 self.handlers.is_empty()
551 }
552}
553
554impl Default for HandlerRegistry {
555 fn default() -> Self {
556 Self::new()
557 }
558}
559
560#[derive(Clone)]
562pub struct PassthroughConfig {
563 pub pattern: MessagePattern,
565 pub upstream_url: String,
567}
568
569impl PassthroughConfig {
570 pub fn new(pattern: MessagePattern, upstream_url: String) -> Self {
572 Self {
573 pattern,
574 upstream_url,
575 }
576 }
577
578 pub fn regex(regex: &str, upstream_url: String) -> HandlerResult<Self> {
580 Ok(Self {
581 pattern: MessagePattern::regex(regex)?,
582 upstream_url,
583 })
584 }
585}
586
587pub struct PassthroughHandler {
589 config: PassthroughConfig,
590 upstream_tx: Mutex<Option<UpstreamSender>>,
592}
593
594type UpstreamSender = futures_util::stream::SplitSink<
595 tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
596 tokio_tungstenite::tungstenite::Message,
597>;
598
599impl PassthroughHandler {
600 pub fn new(config: PassthroughConfig) -> Self {
602 Self {
603 config,
604 upstream_tx: Mutex::new(None),
605 }
606 }
607
608 pub fn should_passthrough(&self, text: &str) -> bool {
610 self.config.pattern.matches(text)
611 }
612
613 pub fn upstream_url(&self) -> &str {
615 &self.config.upstream_url
616 }
617
618 async fn ensure_connected(
621 &self,
622 client_tx: &tokio::sync::mpsc::UnboundedSender<Message>,
623 ) -> HandlerResult<()> {
624 let mut guard = self.upstream_tx.lock().await;
625 if guard.is_some() {
626 return Ok(());
627 }
628
629 let url = &self.config.upstream_url;
630 tracing::info!(upstream = %url, "Connecting to upstream WebSocket server");
631
632 let (ws_stream, _response) = tokio_tungstenite::connect_async(url)
633 .await
634 .map_err(|e| HandlerError::ConnectionError(format!("Upstream connect failed: {e}")))?;
635
636 let (write, mut read) = ws_stream.split();
637 *guard = Some(write);
638
639 let client_tx = client_tx.clone();
641 tokio::spawn(async move {
642 while let Some(Ok(msg)) = read.next().await {
643 let axum_msg = match msg {
644 tokio_tungstenite::tungstenite::Message::Text(t) => {
645 Message::Text(t.to_string().into())
646 }
647 tokio_tungstenite::tungstenite::Message::Binary(b) => {
648 Message::Binary(b.to_vec().into())
649 }
650 tokio_tungstenite::tungstenite::Message::Ping(p) => {
651 Message::Ping(p.to_vec().into())
652 }
653 tokio_tungstenite::tungstenite::Message::Pong(p) => {
654 Message::Pong(p.to_vec().into())
655 }
656 tokio_tungstenite::tungstenite::Message::Close(_) => {
657 break;
658 }
659 tokio_tungstenite::tungstenite::Message::Frame(_) => continue,
660 };
661 if client_tx.send(axum_msg).is_err() {
662 break;
663 }
664 }
665 tracing::debug!("Upstream reader task finished");
666 });
667
668 Ok(())
669 }
670}
671
672#[async_trait]
673impl WsHandler for PassthroughHandler {
674 async fn on_connect(&self, ctx: &mut WsContext) -> HandlerResult<()> {
675 self.ensure_connected(&ctx.message_tx).await
676 }
677
678 async fn on_message(&self, ctx: &mut WsContext, msg: WsMessage) -> HandlerResult<()> {
679 match &msg {
680 WsMessage::Text(text) if self.should_passthrough(text) => {
681 self.ensure_connected(&ctx.message_tx).await?;
682 let mut guard = self.upstream_tx.lock().await;
683 if let Some(ref mut writer) = *guard {
684 writer
685 .send(tokio_tungstenite::tungstenite::Message::Text(text.clone().into()))
686 .await
687 .map_err(|e| {
688 HandlerError::SendError(format!("Upstream send failed: {e}"))
689 })?;
690 }
691 }
692 WsMessage::Binary(data) => {
693 self.ensure_connected(&ctx.message_tx).await?;
694 let mut guard = self.upstream_tx.lock().await;
695 if let Some(ref mut writer) = *guard {
696 writer
697 .send(tokio_tungstenite::tungstenite::Message::Binary(data.clone().into()))
698 .await
699 .map_err(|e| {
700 HandlerError::SendError(format!("Upstream send failed: {e}"))
701 })?;
702 }
703 }
704 _ => {}
705 }
706 Ok(())
707 }
708
709 async fn on_disconnect(&self, _ctx: &mut WsContext) -> HandlerResult<()> {
710 let mut guard = self.upstream_tx.lock().await;
711 if let Some(mut writer) = guard.take() {
712 let _ = writer.send(tokio_tungstenite::tungstenite::Message::Close(None)).await;
713 }
714 Ok(())
715 }
716}
717
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a context for connection `conn123` on `/ws` together with the receiving end
    /// of its outgoing channel.
    fn test_context() -> (WsContext, tokio::sync::mpsc::UnboundedReceiver<Message>) {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let ctx = WsContext::new("conn123".to_string(), "/ws".to_string(), RoomManager::new(), tx);
        (ctx, rx)
    }

    // ==================== WsMessage conversions ====================

    #[test]
    fn test_ws_message_text_from_axum() {
        let converted = WsMessage::from(Message::Text("hello".to_string().into()));
        match converted {
            WsMessage::Text(body) => assert_eq!(body, "hello"),
            _ => panic!("Expected Text message"),
        }
    }

    #[test]
    fn test_ws_message_binary_from_axum() {
        let payload = vec![1, 2, 3, 4];
        let converted = WsMessage::from(Message::Binary(payload.clone().into()));
        match converted {
            WsMessage::Binary(bytes) => assert_eq!(bytes, payload),
            _ => panic!("Expected Binary message"),
        }
    }

    #[test]
    fn test_ws_message_ping_from_axum() {
        let payload = vec![1, 2];
        let converted = WsMessage::from(Message::Ping(payload.clone().into()));
        match converted {
            WsMessage::Ping(bytes) => assert_eq!(bytes, payload),
            _ => panic!("Expected Ping message"),
        }
    }

    #[test]
    fn test_ws_message_pong_from_axum() {
        let payload = vec![3, 4];
        let converted = WsMessage::from(Message::Pong(payload.clone().into()));
        match converted {
            WsMessage::Pong(bytes) => assert_eq!(bytes, payload),
            _ => panic!("Expected Pong message"),
        }
    }

    #[test]
    fn test_ws_message_close_from_axum() {
        let converted = WsMessage::from(Message::Close(None));
        assert!(matches!(converted, WsMessage::Close));
    }

    #[test]
    fn test_ws_message_text_to_axum() {
        let converted = Message::from(WsMessage::Text("hello".to_string()));
        assert!(matches!(converted, Message::Text(_)));
    }

    #[test]
    fn test_ws_message_binary_to_axum() {
        let converted = Message::from(WsMessage::Binary(vec![1, 2, 3]));
        assert!(matches!(converted, Message::Binary(_)));
    }

    #[test]
    fn test_ws_message_close_to_axum() {
        let converted = Message::from(WsMessage::Close);
        assert!(matches!(converted, Message::Close(_)));
    }

    // ==================== MessagePattern ====================

    #[test]
    fn test_message_pattern_regex() {
        let starts_with_hello = MessagePattern::regex(r"^hello").unwrap();
        assert!(starts_with_hello.matches("hello world"));
        assert!(!starts_with_hello.matches("goodbye world"));
    }

    #[test]
    fn test_message_pattern_regex_invalid() {
        assert!(MessagePattern::regex(r"[invalid").is_err());
    }

    #[test]
    fn test_message_pattern_exact() {
        let exact = MessagePattern::exact("hello");
        assert!(exact.matches("hello"));
        assert!(!exact.matches("hello world"));
    }

    #[test]
    fn test_message_pattern_jsonpath() {
        let has_type = MessagePattern::jsonpath("$.type");
        assert!(has_type.matches(r#"{"type": "message"}"#));
        assert!(!has_type.matches(r#"{"name": "test"}"#));
    }

    #[test]
    fn test_message_pattern_jsonpath_nested() {
        let has_user_name = MessagePattern::jsonpath("$.user.name");
        assert!(has_user_name.matches(r#"{"user": {"name": "John"}}"#));
        assert!(!has_user_name.matches(r#"{"user": {"email": "john@example.com"}}"#));
    }

    #[test]
    fn test_message_pattern_jsonpath_invalid_json() {
        let has_type = MessagePattern::jsonpath("$.type");
        assert!(!has_type.matches("not json"));
    }

    #[test]
    fn test_message_pattern_any() {
        let anything = MessagePattern::any();
        for input in ["anything", "", r#"{"json": true}"#].iter() {
            assert!(anything.matches(input));
        }
    }

    #[test]
    fn test_message_pattern_extract() {
        let pattern = MessagePattern::jsonpath("$.type");
        let extracted = pattern.extract(r#"{"type": "greeting", "data": "hello"}"#, "$.type");
        assert_eq!(extracted, Some(serde_json::json!("greeting")));
    }

    #[test]
    fn test_message_pattern_extract_nested() {
        let extracted = MessagePattern::any().extract(r#"{"user": {"id": 123}}"#, "$.user.id");
        assert_eq!(extracted, Some(serde_json::json!(123)));
    }

    #[test]
    fn test_message_pattern_extract_not_found() {
        let extracted = MessagePattern::any().extract(r#"{"type": "message"}"#, "$.nonexistent");
        assert!(extracted.is_none());
    }

    #[test]
    fn test_message_pattern_extract_invalid_json() {
        let extracted = MessagePattern::any().extract("not json", "$.type");
        assert!(extracted.is_none());
    }

    // ==================== RoomManager ====================

    #[tokio::test]
    async fn test_room_manager() {
        let rooms = RoomManager::new();

        // Two connections share room1; conn1 is additionally in room2.
        rooms.join("conn1", "room1").await.unwrap();
        rooms.join("conn1", "room2").await.unwrap();
        rooms.join("conn2", "room1").await.unwrap();

        let in_room1 = rooms.get_room_members("room1").await;
        assert_eq!(in_room1.len(), 2);
        assert!(in_room1.contains(&"conn1".to_string()));
        assert!(in_room1.contains(&"conn2".to_string()));

        let joined_by_conn1 = rooms.get_connection_rooms("conn1").await;
        assert_eq!(joined_by_conn1.len(), 2);
        assert!(joined_by_conn1.contains(&"room1".to_string()));
        assert!(joined_by_conn1.contains(&"room2".to_string()));

        // Leaving a single room only affects that room.
        rooms.leave("conn1", "room1").await.unwrap();
        let in_room1 = rooms.get_room_members("room1").await;
        assert_eq!(in_room1.len(), 1);
        assert!(in_room1.contains(&"conn2".to_string()));

        // Leaving everything clears the connection's room list.
        rooms.leave_all("conn1").await.unwrap();
        let joined_by_conn1 = rooms.get_connection_rooms("conn1").await;
        assert_eq!(joined_by_conn1.len(), 0);
    }

    #[tokio::test]
    async fn test_room_manager_default() {
        let rooms = RoomManager::default();
        rooms.join("conn1", "room1").await.unwrap();
        assert_eq!(rooms.get_room_members("room1").await.len(), 1);
    }

    #[tokio::test]
    async fn test_room_manager_empty_room() {
        let rooms = RoomManager::new();
        assert!(rooms.get_room_members("nonexistent").await.is_empty());
    }

    #[tokio::test]
    async fn test_room_manager_empty_connection() {
        let rooms = RoomManager::new();
        assert!(rooms.get_connection_rooms("nonexistent").await.is_empty());
    }

    #[tokio::test]
    async fn test_room_manager_leave_nonexistent() {
        let rooms = RoomManager::new();
        let outcome = rooms.leave("conn1", "room1").await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_room_manager_broadcaster() {
        let rooms = RoomManager::new();
        rooms.join("conn1", "room1").await.unwrap();

        let sender = rooms.get_broadcaster("room1").await;
        let mut subscriber = sender.subscribe();

        sender.send("hello".to_string()).unwrap();

        let received = subscriber.recv().await.unwrap();
        assert_eq!(received, "hello");
    }

    #[tokio::test]
    async fn test_room_manager_room_cleanup_on_last_leave() {
        let rooms = RoomManager::new();
        rooms.join("conn1", "room1").await.unwrap();
        rooms.leave("conn1", "room1").await.unwrap();

        assert!(rooms.get_room_members("room1").await.is_empty());
    }

    // ==================== MessageRouter ====================

    #[test]
    fn test_message_router() {
        let mut router = MessageRouter::new();
        router.on(MessagePattern::exact("ping"), |_| Some("pong".to_string()));
        router.on(MessagePattern::regex(r"^hello").unwrap(), |_| Some("hi there!".to_string()));

        assert_eq!(router.route("ping"), Some("pong".to_string()));
        assert_eq!(router.route("hello world"), Some("hi there!".to_string()));
        assert_eq!(router.route("goodbye"), None);
    }

    #[test]
    fn test_message_router_default() {
        let router = MessageRouter::default();
        assert_eq!(router.route("anything"), None);
    }

    #[test]
    fn test_message_router_first_match_wins() {
        let mut router = MessageRouter::new();
        router.on(MessagePattern::any(), |_| Some("first".to_string()));
        router.on(MessagePattern::any(), |_| Some("second".to_string()));

        assert_eq!(router.route("test"), Some("first".to_string()));
    }

    #[test]
    fn test_message_router_handler_returns_none() {
        let mut router = MessageRouter::new();
        router.on(MessagePattern::exact("skip"), |_| None);
        router.on(MessagePattern::any(), |_| Some("fallback".to_string()));

        // The first route matches but declines, so the fallback answers.
        assert_eq!(router.route("skip"), Some("fallback".to_string()));
    }

    // ==================== HandlerRegistry ====================

    /// Handler that accepts every path and ignores every message.
    struct TestHandler;

    #[async_trait]
    impl WsHandler for TestHandler {
        async fn on_message(&self, _ctx: &mut WsContext, _msg: WsMessage) -> HandlerResult<()> {
            Ok(())
        }
    }

    /// Handler that only accepts one exact path.
    struct PathSpecificHandler {
        path: String,
    }

    #[async_trait]
    impl WsHandler for PathSpecificHandler {
        async fn on_message(&self, _ctx: &mut WsContext, _msg: WsMessage) -> HandlerResult<()> {
            Ok(())
        }

        fn handles_path(&self, path: &str) -> bool {
            self.path == path
        }
    }

    #[test]
    fn test_handler_registry_new() {
        let registry = HandlerRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn test_handler_registry_default() {
        assert!(HandlerRegistry::default().is_empty());
    }

    #[test]
    fn test_handler_registry_register() {
        let mut registry = HandlerRegistry::new();
        registry.register(TestHandler);
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_handler_registry_get_handlers() {
        let mut registry = HandlerRegistry::new();
        registry.register(TestHandler);

        assert_eq!(registry.get_handlers("/any/path").len(), 1);
    }

    #[test]
    fn test_handler_registry_path_filtering() {
        let mut registry = HandlerRegistry::new();
        for path in ["/ws/chat", "/ws/events"].iter() {
            registry.register(PathSpecificHandler {
                path: path.to_string(),
            });
        }

        assert_eq!(registry.get_handlers("/ws/chat").len(), 1);
        assert_eq!(registry.get_handlers("/ws/events").len(), 1);
        assert!(registry.get_handlers("/ws/other").is_empty());
    }

    #[test]
    fn test_handler_registry_has_handler_for() {
        let mut registry = HandlerRegistry::new();
        registry.register(PathSpecificHandler {
            path: "/ws/chat".to_string(),
        });

        assert!(registry.has_handler_for("/ws/chat"));
        assert!(!registry.has_handler_for("/ws/other"));
    }

    #[test]
    fn test_handler_registry_clear() {
        let mut registry = HandlerRegistry::new();
        registry.register(TestHandler);
        registry.register(TestHandler);
        assert_eq!(registry.len(), 2);

        registry.clear();
        assert!(registry.is_empty());
    }

    #[test]
    fn test_handler_registry_with_hot_reload() {
        assert!(HandlerRegistry::with_hot_reload().is_hot_reload_enabled());
    }

    // ==================== PassthroughConfig ====================

    #[test]
    fn test_passthrough_config_new() {
        let upstream = "ws://upstream:8080".to_string();
        let config = PassthroughConfig::new(MessagePattern::any(), upstream);
        assert_eq!(config.upstream_url, "ws://upstream:8080");
    }

    #[test]
    fn test_passthrough_config_regex() {
        let upstream = "ws://upstream:8080".to_string();
        let config = PassthroughConfig::regex(r"^forward", upstream).unwrap();
        assert!(config.pattern.matches("forward this"));
        assert!(!config.pattern.matches("don't forward"));
    }

    #[test]
    fn test_passthrough_config_regex_invalid() {
        let upstream = "ws://upstream:8080".to_string();
        assert!(PassthroughConfig::regex(r"[invalid", upstream).is_err());
    }

    // ==================== PassthroughHandler ====================

    #[test]
    fn test_passthrough_handler_should_passthrough() {
        let upstream = "ws://upstream:8080".to_string();
        let handler = PassthroughHandler::new(PassthroughConfig::regex(r"^proxy:", upstream).unwrap());

        assert!(handler.should_passthrough("proxy:hello"));
        assert!(!handler.should_passthrough("hello"));
    }

    #[test]
    fn test_passthrough_handler_upstream_url() {
        let upstream = "ws://upstream:8080".to_string();
        let handler = PassthroughHandler::new(PassthroughConfig::new(MessagePattern::any(), upstream));

        assert_eq!(handler.upstream_url(), "ws://upstream:8080");
    }

    // ==================== HandlerError ====================

    #[test]
    fn test_handler_error_send_error() {
        let rendered = HandlerError::SendError("connection closed".to_string()).to_string();
        assert!(rendered.contains("send message"));
        assert!(rendered.contains("connection closed"));
    }

    #[test]
    fn test_handler_error_json_error() {
        let parse_failure = serde_json::from_str::<serde_json::Value>("invalid").unwrap_err();
        let rendered = HandlerError::JsonError(parse_failure).to_string();
        assert!(rendered.contains("JSON"));
    }

    #[test]
    fn test_handler_error_pattern_error() {
        let rendered = HandlerError::PatternError("invalid regex".to_string()).to_string();
        assert!(rendered.contains("Pattern"));
    }

    #[test]
    fn test_handler_error_room_error() {
        let rendered = HandlerError::RoomError("room full".to_string()).to_string();
        assert!(rendered.contains("Room"));
    }

    #[test]
    fn test_handler_error_connection_error() {
        let rendered = HandlerError::ConnectionError("timeout".to_string()).to_string();
        assert!(rendered.contains("Connection"));
    }

    #[test]
    fn test_handler_error_generic() {
        let rendered = HandlerError::Generic("something went wrong".to_string()).to_string();
        assert!(rendered.contains("something went wrong"));
    }

    // ==================== WsContext ====================

    #[tokio::test]
    async fn test_ws_context_metadata() {
        let (ctx, _rx) = test_context();

        ctx.set_metadata("user", serde_json::json!({"id": 1})).await;
        let stored = ctx.get_metadata("user").await;
        assert_eq!(stored, Some(serde_json::json!({"id": 1})));

        let absent = ctx.get_metadata("nonexistent").await;
        assert!(absent.is_none());
    }

    #[tokio::test]
    async fn test_ws_context_send_text() {
        let (ctx, mut rx) = test_context();

        ctx.send_text("hello").await.unwrap();

        let queued = rx.recv().await.unwrap();
        assert!(matches!(queued, Message::Text(_)));
    }

    #[tokio::test]
    async fn test_ws_context_send_binary() {
        let (ctx, mut rx) = test_context();

        ctx.send_binary(vec![1, 2, 3]).await.unwrap();

        let queued = rx.recv().await.unwrap();
        assert!(matches!(queued, Message::Binary(_)));
    }

    #[tokio::test]
    async fn test_ws_context_send_json() {
        let (ctx, mut rx) = test_context();

        ctx.send_json(&serde_json::json!({"type": "test"})).await.unwrap();

        // JSON is delivered as a text frame.
        let queued = rx.recv().await.unwrap();
        assert!(matches!(queued, Message::Text(_)));
    }

    #[tokio::test]
    async fn test_ws_context_rooms() {
        let (ctx, _rx) = test_context();

        ctx.join_room("chat").await.unwrap();
        ctx.join_room("notifications").await.unwrap();
        assert_eq!(ctx.get_rooms().await.len(), 2);

        ctx.leave_room("chat").await.unwrap();
        assert_eq!(ctx.get_rooms().await.len(), 1);
    }
}