1use async_trait::async_trait;
73use axum::extract::ws::Message;
74use regex::Regex;
75use serde_json::Value;
76use std::collections::{HashMap, HashSet};
77use std::sync::Arc;
78use tokio::sync::{broadcast, RwLock};
79
/// Result type returned by every fallible operation in this module.
pub type HandlerResult<T> = Result<T, HandlerError>;
82
/// Errors produced by WebSocket handlers, room management, and message-pattern construction.
#[derive(Debug, thiserror::Error)]
pub enum HandlerError {
    /// Pushing an outbound frame onto the connection's channel failed.
    #[error("Failed to send message: {0}")]
    SendError(String),

    /// JSON (de)serialization failed; produced automatically by `?` on `serde_json::Error`.
    #[error("Failed to parse JSON: {0}")]
    JsonError(#[from] serde_json::Error),

    /// A message pattern (e.g. a regular expression) could not be built.
    #[error("Pattern matching error: {0}")]
    PatternError(String),

    /// A room operation such as broadcasting failed.
    #[error("Room operation failed: {0}")]
    RoomError(String),

    /// A connection-level failure.
    #[error("Connection error: {0}")]
    ConnectionError(String),

    /// Catch-all for handler failures that fit no other variant.
    #[error("Handler error: {0}")]
    Generic(String),
}
110
/// Owned, transport-independent WebSocket frame handed to [`WsHandler::on_message`].
///
/// Converts to and from axum's [`Message`]. Converting an axum close frame into
/// [`WsMessage::Close`] discards its close code and reason.
#[derive(Debug, Clone)]
pub enum WsMessage {
    /// Text frame.
    Text(String),
    /// Binary frame payload.
    Binary(Vec<u8>),
    /// Ping frame payload.
    Ping(Vec<u8>),
    /// Pong frame payload.
    Pong(Vec<u8>),
    /// Close frame (carries no code or reason).
    Close,
}
125
126impl From<Message> for WsMessage {
127 fn from(msg: Message) -> Self {
128 match msg {
129 Message::Text(text) => WsMessage::Text(text.to_string()),
130 Message::Binary(data) => WsMessage::Binary(data.to_vec()),
131 Message::Ping(data) => WsMessage::Ping(data.to_vec()),
132 Message::Pong(data) => WsMessage::Pong(data.to_vec()),
133 Message::Close(_) => WsMessage::Close,
134 }
135 }
136}
137
138impl From<WsMessage> for Message {
139 fn from(msg: WsMessage) -> Self {
140 match msg {
141 WsMessage::Text(text) => Message::Text(text.into()),
142 WsMessage::Binary(data) => Message::Binary(data.into()),
143 WsMessage::Ping(data) => Message::Ping(data.into()),
144 WsMessage::Pong(data) => Message::Pong(data.into()),
145 WsMessage::Close => Message::Close(None),
146 }
147 }
148}
149
/// Predicate deciding whether an incoming text message is handled.
#[derive(Debug, Clone)]
pub enum MessagePattern {
    /// Matches when the regular expression finds a match in the text (anchor it to restrict).
    Regex(Regex),
    /// Matches when the text parses as JSON and this JSONPath query selects at least one value.
    JsonPath(String),
    /// Matches only text exactly equal to the stored string.
    Exact(String),
    /// Matches every message.
    Any,
}
162
163impl MessagePattern {
164 pub fn regex(pattern: &str) -> HandlerResult<Self> {
166 Ok(MessagePattern::Regex(
167 Regex::new(pattern).map_err(|e| HandlerError::PatternError(e.to_string()))?,
168 ))
169 }
170
171 pub fn jsonpath(query: &str) -> Self {
173 MessagePattern::JsonPath(query.to_string())
174 }
175
176 pub fn exact(text: &str) -> Self {
178 MessagePattern::Exact(text.to_string())
179 }
180
181 pub fn any() -> Self {
183 MessagePattern::Any
184 }
185
186 pub fn matches(&self, text: &str) -> bool {
188 match self {
189 MessagePattern::Regex(re) => re.is_match(text),
190 MessagePattern::JsonPath(query) => {
191 if let Ok(json) = serde_json::from_str::<Value>(text) {
193 if let Ok(selector) = jsonpath::Selector::new(query) {
195 let results: Vec<_> = selector.find(&json).collect();
196 !results.is_empty()
197 } else {
198 false
199 }
200 } else {
201 false
202 }
203 }
204 MessagePattern::Exact(expected) => text == expected,
205 MessagePattern::Any => true,
206 }
207 }
208
209 pub fn extract(&self, text: &str, query: &str) -> Option<Value> {
211 if let Ok(json) = serde_json::from_str::<Value>(text) {
212 if let Ok(selector) = jsonpath::Selector::new(query) {
213 let results: Vec<_> = selector.find(&json).collect();
214 results.first().cloned().cloned()
215 } else {
216 None
217 }
218 } else {
219 None
220 }
221 }
222}
223
/// Identifier of a single WebSocket connection, as used by [`RoomManager`] and [`WsContext`].
pub type ConnectionId = String;
226
/// Tracks which connections are in which rooms, plus one broadcast channel per room.
///
/// Cloning is cheap: every field is an `Arc`, so all clones share the same state.
#[derive(Clone)]
pub struct RoomManager {
    /// Room name -> IDs of the connections currently in that room.
    rooms: Arc<RwLock<HashMap<String, HashSet<ConnectionId>>>>,
    /// Connection ID -> names of the rooms it has joined (reverse index of `rooms`).
    connections: Arc<RwLock<HashMap<ConnectionId, HashSet<String>>>>,
    /// Room name -> broadcast sender, created on first request by `get_broadcaster`.
    broadcasters: Arc<RwLock<HashMap<String, broadcast::Sender<String>>>>,
}
234
235impl RoomManager {
236 pub fn new() -> Self {
238 Self {
239 rooms: Arc::new(RwLock::new(HashMap::new())),
240 connections: Arc::new(RwLock::new(HashMap::new())),
241 broadcasters: Arc::new(RwLock::new(HashMap::new())),
242 }
243 }
244
245 pub async fn join(&self, conn_id: &str, room: &str) -> HandlerResult<()> {
247 let mut rooms = self.rooms.write().await;
248 let mut connections = self.connections.write().await;
249
250 rooms
251 .entry(room.to_string())
252 .or_insert_with(HashSet::new)
253 .insert(conn_id.to_string());
254
255 connections
256 .entry(conn_id.to_string())
257 .or_insert_with(HashSet::new)
258 .insert(room.to_string());
259
260 Ok(())
261 }
262
263 pub async fn leave(&self, conn_id: &str, room: &str) -> HandlerResult<()> {
265 let mut rooms = self.rooms.write().await;
266 let mut connections = self.connections.write().await;
267
268 if let Some(room_members) = rooms.get_mut(room) {
269 room_members.remove(conn_id);
270 if room_members.is_empty() {
271 rooms.remove(room);
272 }
273 }
274
275 if let Some(conn_rooms) = connections.get_mut(conn_id) {
276 conn_rooms.remove(room);
277 if conn_rooms.is_empty() {
278 connections.remove(conn_id);
279 }
280 }
281
282 Ok(())
283 }
284
285 pub async fn leave_all(&self, conn_id: &str) -> HandlerResult<()> {
287 let mut connections = self.connections.write().await;
288 if let Some(conn_rooms) = connections.remove(conn_id) {
289 let mut rooms = self.rooms.write().await;
290 for room in conn_rooms {
291 if let Some(room_members) = rooms.get_mut(&room) {
292 room_members.remove(conn_id);
293 if room_members.is_empty() {
294 rooms.remove(&room);
295 }
296 }
297 }
298 }
299 Ok(())
300 }
301
302 pub async fn get_room_members(&self, room: &str) -> Vec<ConnectionId> {
304 let rooms = self.rooms.read().await;
305 rooms
306 .get(room)
307 .map(|members| members.iter().cloned().collect())
308 .unwrap_or_default()
309 }
310
311 pub async fn get_connection_rooms(&self, conn_id: &str) -> Vec<String> {
313 let connections = self.connections.read().await;
314 connections
315 .get(conn_id)
316 .map(|rooms| rooms.iter().cloned().collect())
317 .unwrap_or_default()
318 }
319
320 pub async fn get_broadcaster(&self, room: &str) -> broadcast::Sender<String> {
322 let mut broadcasters = self.broadcasters.write().await;
323 broadcasters
324 .entry(room.to_string())
325 .or_insert_with(|| {
326 let (tx, _) = broadcast::channel(1024);
327 tx
328 })
329 .clone()
330 }
331}
332
333impl Default for RoomManager {
334 fn default() -> Self {
335 Self::new()
336 }
337}
338
/// Per-connection state passed to [`WsHandler`] callbacks.
pub struct WsContext {
    /// Identifier of this connection.
    pub connection_id: ConnectionId,
    /// Path associated with this connection (presumably the upgrade request path; set by the caller of `new`).
    pub path: String,
    /// Shared room registry used for joining, leaving, and broadcasting.
    room_manager: RoomManager,
    /// Sender for outbound frames. The `send_*` methods push onto it; the receiving end is
    /// owned elsewhere (presumably the socket writer task).
    message_tx: tokio::sync::mpsc::UnboundedSender<Message>,
    /// Arbitrary per-connection key/value data that handlers set and read.
    metadata: Arc<RwLock<HashMap<String, Value>>>,
}
352
353impl WsContext {
354 pub fn new(
356 connection_id: ConnectionId,
357 path: String,
358 room_manager: RoomManager,
359 message_tx: tokio::sync::mpsc::UnboundedSender<Message>,
360 ) -> Self {
361 Self {
362 connection_id,
363 path,
364 room_manager,
365 message_tx,
366 metadata: Arc::new(RwLock::new(HashMap::new())),
367 }
368 }
369
370 pub async fn send_text(&self, text: &str) -> HandlerResult<()> {
372 self.message_tx
373 .send(Message::Text(text.to_string().into()))
374 .map_err(|e| HandlerError::SendError(e.to_string()))
375 }
376
377 pub async fn send_binary(&self, data: Vec<u8>) -> HandlerResult<()> {
379 self.message_tx
380 .send(Message::Binary(data.into()))
381 .map_err(|e| HandlerError::SendError(e.to_string()))
382 }
383
384 pub async fn send_json(&self, value: &Value) -> HandlerResult<()> {
386 let text = serde_json::to_string(value)?;
387 self.send_text(&text).await
388 }
389
390 pub async fn join_room(&self, room: &str) -> HandlerResult<()> {
392 self.room_manager.join(&self.connection_id, room).await
393 }
394
395 pub async fn leave_room(&self, room: &str) -> HandlerResult<()> {
397 self.room_manager.leave(&self.connection_id, room).await
398 }
399
400 pub async fn broadcast_to_room(&self, room: &str, text: &str) -> HandlerResult<()> {
402 let broadcaster = self.room_manager.get_broadcaster(room).await;
403 broadcaster
404 .send(text.to_string())
405 .map_err(|e| HandlerError::RoomError(e.to_string()))?;
406 Ok(())
407 }
408
409 pub async fn get_rooms(&self) -> Vec<String> {
411 self.room_manager.get_connection_rooms(&self.connection_id).await
412 }
413
414 pub async fn set_metadata(&self, key: &str, value: Value) {
416 let mut metadata = self.metadata.write().await;
417 metadata.insert(key.to_string(), value);
418 }
419
420 pub async fn get_metadata(&self, key: &str) -> Option<Value> {
422 let metadata = self.metadata.read().await;
423 metadata.get(key).cloned()
424 }
425}
426
427#[async_trait]
429pub trait WsHandler: Send + Sync {
430 async fn on_connect(&self, _ctx: &mut WsContext) -> HandlerResult<()> {
432 Ok(())
433 }
434
435 async fn on_message(&self, ctx: &mut WsContext, msg: WsMessage) -> HandlerResult<()>;
437
438 async fn on_disconnect(&self, _ctx: &mut WsContext) -> HandlerResult<()> {
440 Ok(())
441 }
442
443 fn handles_path(&self, _path: &str) -> bool {
445 true }
447}
448
/// Ordered list of (pattern, handler) routes that produce synchronous text replies.
pub struct MessageRouter {
    /// Routes in registration order. Each handler receives the message text and returns a
    /// reply, or `None` to let later routes try.
    routes: Vec<(MessagePattern, Box<dyn Fn(String) -> Option<String> + Send + Sync>)>,
}
453
454impl MessageRouter {
455 pub fn new() -> Self {
457 Self { routes: Vec::new() }
458 }
459
460 pub fn on<F>(&mut self, pattern: MessagePattern, handler: F) -> &mut Self
462 where
463 F: Fn(String) -> Option<String> + Send + Sync + 'static,
464 {
465 self.routes.push((pattern, Box::new(handler)));
466 self
467 }
468
469 pub fn route(&self, text: &str) -> Option<String> {
471 for (pattern, handler) in &self.routes {
472 if pattern.matches(text) {
473 if let Some(response) = handler(text.to_string()) {
474 return Some(response);
475 }
476 }
477 }
478 None
479 }
480}
481
482impl Default for MessageRouter {
483 fn default() -> Self {
484 Self::new()
485 }
486}
487
/// Collection of registered [`WsHandler`]s, selected per connection path.
pub struct HandlerRegistry {
    /// Handlers in registration order.
    handlers: Vec<Arc<dyn WsHandler>>,
    /// Whether hot reload was requested, either via `MOCKFORGE_WS_HOTRELOAD` or
    /// `with_hot_reload`. This type only stores and reports the flag.
    hot_reload_enabled: bool,
}
493
494impl HandlerRegistry {
495 pub fn new() -> Self {
497 Self {
498 handlers: Vec::new(),
499 hot_reload_enabled: std::env::var("MOCKFORGE_WS_HOTRELOAD")
500 .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
501 .unwrap_or(false),
502 }
503 }
504
505 pub fn with_hot_reload() -> Self {
507 Self {
508 handlers: Vec::new(),
509 hot_reload_enabled: true,
510 }
511 }
512
513 pub fn is_hot_reload_enabled(&self) -> bool {
515 self.hot_reload_enabled
516 }
517
518 pub fn register<H: WsHandler + 'static>(&mut self, handler: H) -> &mut Self {
520 self.handlers.push(Arc::new(handler));
521 self
522 }
523
524 pub fn get_handlers(&self, path: &str) -> Vec<Arc<dyn WsHandler>> {
526 self.handlers.iter().filter(|h| h.handles_path(path)).cloned().collect()
527 }
528
529 pub fn has_handler_for(&self, path: &str) -> bool {
531 self.handlers.iter().any(|h| h.handles_path(path))
532 }
533
534 pub fn clear(&mut self) {
536 self.handlers.clear();
537 }
538
539 pub fn len(&self) -> usize {
541 self.handlers.len()
542 }
543
544 pub fn is_empty(&self) -> bool {
546 self.handlers.is_empty()
547 }
548}
549
550impl Default for HandlerRegistry {
551 fn default() -> Self {
552 Self::new()
553 }
554}
555
/// Decides which messages are treated as pass-through, and which upstream they belong to.
#[derive(Clone)]
pub struct PassthroughConfig {
    /// Messages matching this pattern are passed through.
    pub pattern: MessagePattern,
    /// Upstream WebSocket URL (e.g. `ws://upstream:8080`) associated with matching messages.
    pub upstream_url: String,
}
564
565impl PassthroughConfig {
566 pub fn new(pattern: MessagePattern, upstream_url: String) -> Self {
568 Self {
569 pattern,
570 upstream_url,
571 }
572 }
573
574 pub fn regex(regex: &str, upstream_url: String) -> HandlerResult<Self> {
576 Ok(Self {
577 pattern: MessagePattern::regex(regex)?,
578 upstream_url,
579 })
580 }
581}
582
/// [`WsHandler`] that classifies text messages using a [`PassthroughConfig`].
pub struct PassthroughHandler {
    /// Pattern and upstream URL used to classify messages.
    config: PassthroughConfig,
}
587
588impl PassthroughHandler {
589 pub fn new(config: PassthroughConfig) -> Self {
591 Self { config }
592 }
593
594 pub fn should_passthrough(&self, text: &str) -> bool {
596 self.config.pattern.matches(text)
597 }
598
599 pub fn upstream_url(&self) -> &str {
601 &self.config.upstream_url
602 }
603}
604
605#[async_trait]
606impl WsHandler for PassthroughHandler {
607 async fn on_message(&self, ctx: &mut WsContext, msg: WsMessage) -> HandlerResult<()> {
608 if let WsMessage::Text(text) = &msg {
609 if self.should_passthrough(text) {
610 ctx.send_text(&format!("PASSTHROUGH({}): {}", self.config.upstream_url, text))
613 .await?;
614 return Ok(());
615 }
616 }
617 Ok(())
618 }
619}
620
// Unit tests for this module, grouped by component. The async tests use
// `#[tokio::test]` because the room and context APIs are `async`.
#[cfg(test)]
mod tests {
    use super::*;

    // ----- axum `Message` -> `WsMessage` conversions -----

    #[test]
    fn test_ws_message_text_from_axum() {
        let axum_msg = Message::Text("hello".to_string().into());
        let ws_msg: WsMessage = axum_msg.into();
        match ws_msg {
            WsMessage::Text(text) => assert_eq!(text, "hello"),
            _ => panic!("Expected Text message"),
        }
    }

    #[test]
    fn test_ws_message_binary_from_axum() {
        let data = vec![1, 2, 3, 4];
        let axum_msg = Message::Binary(data.clone().into());
        let ws_msg: WsMessage = axum_msg.into();
        match ws_msg {
            WsMessage::Binary(bytes) => assert_eq!(bytes, data),
            _ => panic!("Expected Binary message"),
        }
    }

    #[test]
    fn test_ws_message_ping_from_axum() {
        let data = vec![1, 2];
        let axum_msg = Message::Ping(data.clone().into());
        let ws_msg: WsMessage = axum_msg.into();
        match ws_msg {
            WsMessage::Ping(bytes) => assert_eq!(bytes, data),
            _ => panic!("Expected Ping message"),
        }
    }

    #[test]
    fn test_ws_message_pong_from_axum() {
        let data = vec![3, 4];
        let axum_msg = Message::Pong(data.clone().into());
        let ws_msg: WsMessage = axum_msg.into();
        match ws_msg {
            WsMessage::Pong(bytes) => assert_eq!(bytes, data),
            _ => panic!("Expected Pong message"),
        }
    }

    #[test]
    fn test_ws_message_close_from_axum() {
        let axum_msg = Message::Close(None);
        let ws_msg: WsMessage = axum_msg.into();
        assert!(matches!(ws_msg, WsMessage::Close));
    }

    // ----- `WsMessage` -> axum `Message` conversions (variant only, payload not checked) -----

    #[test]
    fn test_ws_message_text_to_axum() {
        let ws_msg = WsMessage::Text("hello".to_string());
        let axum_msg: Message = ws_msg.into();
        assert!(matches!(axum_msg, Message::Text(_)));
    }

    #[test]
    fn test_ws_message_binary_to_axum() {
        let ws_msg = WsMessage::Binary(vec![1, 2, 3]);
        let axum_msg: Message = ws_msg.into();
        assert!(matches!(axum_msg, Message::Binary(_)));
    }

    #[test]
    fn test_ws_message_close_to_axum() {
        let ws_msg = WsMessage::Close;
        let axum_msg: Message = ws_msg.into();
        assert!(matches!(axum_msg, Message::Close(_)));
    }

    // ----- MessagePattern matching -----

    #[test]
    fn test_message_pattern_regex() {
        let pattern = MessagePattern::regex(r"^hello").unwrap();
        assert!(pattern.matches("hello world"));
        assert!(!pattern.matches("goodbye world"));
    }

    #[test]
    fn test_message_pattern_regex_invalid() {
        let result = MessagePattern::regex(r"[invalid");
        assert!(result.is_err());
    }

    #[test]
    fn test_message_pattern_exact() {
        let pattern = MessagePattern::exact("hello");
        assert!(pattern.matches("hello"));
        assert!(!pattern.matches("hello world"));
    }

    // A JSONPath pattern matches when the selected key exists.
    #[test]
    fn test_message_pattern_jsonpath() {
        let pattern = MessagePattern::jsonpath("$.type");
        assert!(pattern.matches(r#"{"type": "message"}"#));
        assert!(!pattern.matches(r#"{"name": "test"}"#));
    }

    #[test]
    fn test_message_pattern_jsonpath_nested() {
        let pattern = MessagePattern::jsonpath("$.user.name");
        assert!(pattern.matches(r#"{"user": {"name": "John"}}"#));
        assert!(!pattern.matches(r#"{"user": {"email": "john@example.com"}}"#));
    }

    #[test]
    fn test_message_pattern_jsonpath_invalid_json() {
        let pattern = MessagePattern::jsonpath("$.type");
        assert!(!pattern.matches("not json"));
    }

    #[test]
    fn test_message_pattern_any() {
        let pattern = MessagePattern::any();
        assert!(pattern.matches("anything"));
        assert!(pattern.matches(""));
        assert!(pattern.matches(r#"{"json": true}"#));
    }

    // ----- MessagePattern::extract (the query argument decides; the receiver's variant is unused) -----

    #[test]
    fn test_message_pattern_extract() {
        let pattern = MessagePattern::jsonpath("$.type");
        let result = pattern.extract(r#"{"type": "greeting", "data": "hello"}"#, "$.type");
        assert_eq!(result, Some(serde_json::json!("greeting")));
    }

    #[test]
    fn test_message_pattern_extract_nested() {
        let pattern = MessagePattern::any();
        let result = pattern.extract(r#"{"user": {"id": 123}}"#, "$.user.id");
        assert_eq!(result, Some(serde_json::json!(123)));
    }

    #[test]
    fn test_message_pattern_extract_not_found() {
        let pattern = MessagePattern::any();
        let result = pattern.extract(r#"{"type": "message"}"#, "$.nonexistent");
        assert!(result.is_none());
    }

    #[test]
    fn test_message_pattern_extract_invalid_json() {
        let pattern = MessagePattern::any();
        let result = pattern.extract("not json", "$.type");
        assert!(result.is_none());
    }

    // ----- RoomManager -----

    // Walks through join, both lookup directions, leave, and leave_all.
    #[tokio::test]
    async fn test_room_manager() {
        let manager = RoomManager::new();

        manager.join("conn1", "room1").await.unwrap();
        manager.join("conn1", "room2").await.unwrap();
        manager.join("conn2", "room1").await.unwrap();

        let room1_members = manager.get_room_members("room1").await;
        assert_eq!(room1_members.len(), 2);
        assert!(room1_members.contains(&"conn1".to_string()));
        assert!(room1_members.contains(&"conn2".to_string()));

        let conn1_rooms = manager.get_connection_rooms("conn1").await;
        assert_eq!(conn1_rooms.len(), 2);
        assert!(conn1_rooms.contains(&"room1".to_string()));
        assert!(conn1_rooms.contains(&"room2".to_string()));

        manager.leave("conn1", "room1").await.unwrap();
        let room1_members = manager.get_room_members("room1").await;
        assert_eq!(room1_members.len(), 1);
        assert!(room1_members.contains(&"conn2".to_string()));

        manager.leave_all("conn1").await.unwrap();
        let conn1_rooms = manager.get_connection_rooms("conn1").await;
        assert_eq!(conn1_rooms.len(), 0);
    }

    #[tokio::test]
    async fn test_room_manager_default() {
        let manager = RoomManager::default();
        manager.join("conn1", "room1").await.unwrap();
        let members = manager.get_room_members("room1").await;
        assert_eq!(members.len(), 1);
    }

    #[tokio::test]
    async fn test_room_manager_empty_room() {
        let manager = RoomManager::new();
        let members = manager.get_room_members("nonexistent").await;
        assert!(members.is_empty());
    }

    #[tokio::test]
    async fn test_room_manager_empty_connection() {
        let manager = RoomManager::new();
        let rooms = manager.get_connection_rooms("nonexistent").await;
        assert!(rooms.is_empty());
    }

    // Leaving a room that was never joined is not an error.
    #[tokio::test]
    async fn test_room_manager_leave_nonexistent() {
        let manager = RoomManager::new();
        let result = manager.leave("conn1", "room1").await;
        assert!(result.is_ok());
    }

    // Subscribes before sending: a tokio broadcast send with no live receivers fails.
    #[tokio::test]
    async fn test_room_manager_broadcaster() {
        let manager = RoomManager::new();
        manager.join("conn1", "room1").await.unwrap();

        let broadcaster = manager.get_broadcaster("room1").await;
        let mut receiver = broadcaster.subscribe();

        broadcaster.send("hello".to_string()).unwrap();

        let msg = receiver.recv().await.unwrap();
        assert_eq!(msg, "hello");
    }

    // NOTE(review): `get_room_members` also returns an empty list for a missing room, so this
    // assertion cannot distinguish "room removed" from "room kept but empty".
    #[tokio::test]
    async fn test_room_manager_room_cleanup_on_last_leave() {
        let manager = RoomManager::new();
        manager.join("conn1", "room1").await.unwrap();
        manager.leave("conn1", "room1").await.unwrap();

        let members = manager.get_room_members("room1").await;
        assert!(members.is_empty());
    }

    // ----- MessageRouter -----

    #[test]
    fn test_message_router() {
        let mut router = MessageRouter::new();

        router
            .on(MessagePattern::exact("ping"), |_| Some("pong".to_string()))
            .on(MessagePattern::regex(r"^hello").unwrap(), |_| Some("hi there!".to_string()));

        assert_eq!(router.route("ping"), Some("pong".to_string()));
        assert_eq!(router.route("hello world"), Some("hi there!".to_string()));
        assert_eq!(router.route("goodbye"), None);
    }

    #[test]
    fn test_message_router_default() {
        let router = MessageRouter::default();
        assert_eq!(router.route("anything"), None);
    }

    #[test]
    fn test_message_router_first_match_wins() {
        let mut router = MessageRouter::new();
        router
            .on(MessagePattern::any(), |_| Some("first".to_string()))
            .on(MessagePattern::any(), |_| Some("second".to_string()));

        assert_eq!(router.route("test"), Some("first".to_string()));
    }

    // A matching handler that returns `None` falls through to later routes.
    #[test]
    fn test_message_router_handler_returns_none() {
        let mut router = MessageRouter::new();
        router
            .on(MessagePattern::exact("skip"), |_| None)
            .on(MessagePattern::any(), |_| Some("fallback".to_string()));

        assert_eq!(router.route("skip"), Some("fallback".to_string()));
    }

    // ----- HandlerRegistry -----

    // No-op handler that keeps the default `handles_path` (accepts every path).
    struct TestHandler;

    #[async_trait]
    impl WsHandler for TestHandler {
        async fn on_message(&self, _ctx: &mut WsContext, _msg: WsMessage) -> HandlerResult<()> {
            Ok(())
        }
    }

    // No-op handler that accepts exactly one path.
    struct PathSpecificHandler {
        path: String,
    }

    #[async_trait]
    impl WsHandler for PathSpecificHandler {
        async fn on_message(&self, _ctx: &mut WsContext, _msg: WsMessage) -> HandlerResult<()> {
            Ok(())
        }

        fn handles_path(&self, path: &str) -> bool {
            path == self.path
        }
    }

    #[test]
    fn test_handler_registry_new() {
        let registry = HandlerRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn test_handler_registry_default() {
        let registry = HandlerRegistry::default();
        assert!(registry.is_empty());
    }

    #[test]
    fn test_handler_registry_register() {
        let mut registry = HandlerRegistry::new();
        registry.register(TestHandler);
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_handler_registry_get_handlers() {
        let mut registry = HandlerRegistry::new();
        registry.register(TestHandler);

        let handlers = registry.get_handlers("/any/path");
        assert_eq!(handlers.len(), 1);
    }

    #[test]
    fn test_handler_registry_path_filtering() {
        let mut registry = HandlerRegistry::new();
        registry.register(PathSpecificHandler {
            path: "/ws/chat".to_string(),
        });
        registry.register(PathSpecificHandler {
            path: "/ws/events".to_string(),
        });

        let chat_handlers = registry.get_handlers("/ws/chat");
        assert_eq!(chat_handlers.len(), 1);

        let events_handlers = registry.get_handlers("/ws/events");
        assert_eq!(events_handlers.len(), 1);

        let other_handlers = registry.get_handlers("/ws/other");
        assert!(other_handlers.is_empty());
    }

    #[test]
    fn test_handler_registry_has_handler_for() {
        let mut registry = HandlerRegistry::new();
        registry.register(PathSpecificHandler {
            path: "/ws/chat".to_string(),
        });

        assert!(registry.has_handler_for("/ws/chat"));
        assert!(!registry.has_handler_for("/ws/other"));
    }

    #[test]
    fn test_handler_registry_clear() {
        let mut registry = HandlerRegistry::new();
        registry.register(TestHandler);
        registry.register(TestHandler);
        assert_eq!(registry.len(), 2);

        registry.clear();
        assert!(registry.is_empty());
    }

    #[test]
    fn test_handler_registry_with_hot_reload() {
        let registry = HandlerRegistry::with_hot_reload();
        assert!(registry.is_hot_reload_enabled());
    }

    // ----- Pass-through configuration and handler -----

    #[test]
    fn test_passthrough_config_new() {
        let config =
            PassthroughConfig::new(MessagePattern::any(), "ws://upstream:8080".to_string());
        assert_eq!(config.upstream_url, "ws://upstream:8080");
    }

    #[test]
    fn test_passthrough_config_regex() {
        let config =
            PassthroughConfig::regex(r"^forward", "ws://upstream:8080".to_string()).unwrap();
        assert!(config.pattern.matches("forward this"));
        assert!(!config.pattern.matches("don't forward"));
    }

    #[test]
    fn test_passthrough_config_regex_invalid() {
        let result = PassthroughConfig::regex(r"[invalid", "ws://upstream:8080".to_string());
        assert!(result.is_err());
    }

    #[test]
    fn test_passthrough_handler_should_passthrough() {
        let config =
            PassthroughConfig::regex(r"^proxy:", "ws://upstream:8080".to_string()).unwrap();
        let handler = PassthroughHandler::new(config);

        assert!(handler.should_passthrough("proxy:hello"));
        assert!(!handler.should_passthrough("hello"));
    }

    #[test]
    fn test_passthrough_handler_upstream_url() {
        let config =
            PassthroughConfig::new(MessagePattern::any(), "ws://upstream:8080".to_string());
        let handler = PassthroughHandler::new(config);

        assert_eq!(handler.upstream_url(), "ws://upstream:8080");
    }

    // ----- HandlerError display strings (substring checks against the `#[error]` formats) -----

    #[test]
    fn test_handler_error_send_error() {
        let err = HandlerError::SendError("connection closed".to_string());
        assert!(err.to_string().contains("send message"));
        assert!(err.to_string().contains("connection closed"));
    }

    #[test]
    fn test_handler_error_json_error() {
        let json_err = serde_json::from_str::<serde_json::Value>("invalid").unwrap_err();
        let err = HandlerError::JsonError(json_err);
        assert!(err.to_string().contains("JSON"));
    }

    #[test]
    fn test_handler_error_pattern_error() {
        let err = HandlerError::PatternError("invalid regex".to_string());
        assert!(err.to_string().contains("Pattern"));
    }

    #[test]
    fn test_handler_error_room_error() {
        let err = HandlerError::RoomError("room full".to_string());
        assert!(err.to_string().contains("Room"));
    }

    #[test]
    fn test_handler_error_connection_error() {
        let err = HandlerError::ConnectionError("timeout".to_string());
        assert!(err.to_string().contains("Connection"));
    }

    #[test]
    fn test_handler_error_generic() {
        let err = HandlerError::Generic("something went wrong".to_string());
        assert!(err.to_string().contains("something went wrong"));
    }

    // ----- WsContext (an unbounded mpsc channel stands in for the socket writer) -----

    #[tokio::test]
    async fn test_ws_context_metadata() {
        let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
        let ctx = WsContext::new("conn123".to_string(), "/ws".to_string(), RoomManager::new(), tx);

        ctx.set_metadata("user", serde_json::json!({"id": 1})).await;
        let value = ctx.get_metadata("user").await;
        assert_eq!(value, Some(serde_json::json!({"id": 1})));

        let missing = ctx.get_metadata("nonexistent").await;
        assert!(missing.is_none());
    }

    #[tokio::test]
    async fn test_ws_context_send_text() {
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let ctx = WsContext::new("conn123".to_string(), "/ws".to_string(), RoomManager::new(), tx);

        ctx.send_text("hello").await.unwrap();

        let msg = rx.recv().await.unwrap();
        assert!(matches!(msg, Message::Text(_)));
    }

    #[tokio::test]
    async fn test_ws_context_send_binary() {
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let ctx = WsContext::new("conn123".to_string(), "/ws".to_string(), RoomManager::new(), tx);

        ctx.send_binary(vec![1, 2, 3]).await.unwrap();

        let msg = rx.recv().await.unwrap();
        assert!(matches!(msg, Message::Binary(_)));
    }

    // JSON is sent as a text frame.
    #[tokio::test]
    async fn test_ws_context_send_json() {
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let ctx = WsContext::new("conn123".to_string(), "/ws".to_string(), RoomManager::new(), tx);

        ctx.send_json(&serde_json::json!({"type": "test"})).await.unwrap();

        let msg = rx.recv().await.unwrap();
        assert!(matches!(msg, Message::Text(_)));
    }

    #[tokio::test]
    async fn test_ws_context_rooms() {
        let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
        let ctx = WsContext::new("conn123".to_string(), "/ws".to_string(), RoomManager::new(), tx);

        ctx.join_room("chat").await.unwrap();
        ctx.join_room("notifications").await.unwrap();

        let rooms = ctx.get_rooms().await;
        assert_eq!(rooms.len(), 2);

        ctx.leave_room("chat").await.unwrap();
        let rooms = ctx.get_rooms().await;
        assert_eq!(rooms.len(), 1);
    }
}