1use async_trait::async_trait;
73use axum::extract::ws::Message;
74use futures_util::{SinkExt, StreamExt};
75use regex::Regex;
76use serde_json::Value;
77use std::collections::{HashMap, HashSet};
78use std::sync::Arc;
79use tokio::sync::{broadcast, Mutex, RwLock};
80
/// Result alias used by the handler APIs in this module; the error type is
/// always [`HandlerError`].
pub type HandlerResult<T> = Result<T, HandlerError>;
83
/// Errors returned by the WebSocket handler APIs in this module.
///
/// Every variant except [`HandlerError::JsonError`] carries a free-form
/// description that is interpolated into the `Display` message.
#[derive(Debug, thiserror::Error)]
pub enum HandlerError {
    /// A message could not be queued for the client or written to an upstream.
    #[error("Failed to send message: {0}")]
    SendError(String),

    /// JSON (de)serialization failed; built automatically from
    /// `serde_json::Error` through `#[from]`, so `?` works on serde results.
    #[error("Failed to parse JSON: {0}")]
    JsonError(#[from] serde_json::Error),

    /// A message pattern could not be built (e.g. an invalid regular expression).
    #[error("Pattern matching error: {0}")]
    PatternError(String),

    /// A room operation (such as broadcasting) failed.
    #[error("Room operation failed: {0}")]
    RoomError(String),

    /// Establishing a connection (e.g. to an upstream server) failed.
    #[error("Connection error: {0}")]
    ConnectionError(String),

    /// Catch-all variant for handler-defined failures.
    // NOTE(review): not constructed anywhere in this file; presumably meant
    // for user-written handlers.
    #[error("Handler error: {0}")]
    Generic(String),
}
111
/// Owned, framework-independent representation of a WebSocket message.
///
/// Conversions to and from axum's `Message` are provided via `From`.
#[derive(Debug, Clone)]
pub enum WsMessage {
    /// UTF-8 text payload.
    Text(String),
    /// Binary payload.
    Binary(Vec<u8>),
    /// Ping control frame with its payload.
    Ping(Vec<u8>),
    /// Pong control frame with its payload.
    Pong(Vec<u8>),
    /// Close notification; carries no close code or reason.
    Close,
}
126
127impl From<Message> for WsMessage {
128 fn from(msg: Message) -> Self {
129 match msg {
130 Message::Text(text) => WsMessage::Text(text.to_string()),
131 Message::Binary(data) => WsMessage::Binary(data.to_vec()),
132 Message::Ping(data) => WsMessage::Ping(data.to_vec()),
133 Message::Pong(data) => WsMessage::Pong(data.to_vec()),
134 Message::Close(_) => WsMessage::Close,
135 }
136 }
137}
138
139impl From<WsMessage> for Message {
140 fn from(msg: WsMessage) -> Self {
141 match msg {
142 WsMessage::Text(text) => Message::Text(text.into()),
143 WsMessage::Binary(data) => Message::Binary(data.into()),
144 WsMessage::Ping(data) => Message::Ping(data.into()),
145 WsMessage::Pong(data) => Message::Pong(data.into()),
146 WsMessage::Close => Message::Close(None),
147 }
148 }
149}
150
/// A predicate over text messages, evaluated by `MessagePattern::matches`.
#[derive(Debug, Clone)]
pub enum MessagePattern {
    /// Matches when the compiled regular expression finds a match in the text.
    Regex(Regex),
    /// Matches when the text parses as JSON and the JSONPath query selects at
    /// least one node.
    JsonPath(String),
    /// Matches only when the text is exactly equal to the stored string.
    Exact(String),
    /// Matches every message.
    Any,
}
163
164impl MessagePattern {
165 pub fn regex(pattern: &str) -> HandlerResult<Self> {
167 Ok(MessagePattern::Regex(
168 Regex::new(pattern).map_err(|e| HandlerError::PatternError(e.to_string()))?,
169 ))
170 }
171
172 pub fn jsonpath(query: &str) -> Self {
174 MessagePattern::JsonPath(query.to_string())
175 }
176
177 pub fn exact(text: &str) -> Self {
179 MessagePattern::Exact(text.to_string())
180 }
181
182 pub fn any() -> Self {
184 MessagePattern::Any
185 }
186
187 pub fn matches(&self, text: &str) -> bool {
189 match self {
190 MessagePattern::Regex(re) => re.is_match(text),
191 MessagePattern::JsonPath(query) => {
192 if let Ok(json) = serde_json::from_str::<Value>(text) {
194 if let Ok(selector) = jsonpath::Selector::new(query) {
196 let results: Vec<_> = selector.find(&json).collect();
197 !results.is_empty()
198 } else {
199 false
200 }
201 } else {
202 false
203 }
204 }
205 MessagePattern::Exact(expected) => text == expected,
206 MessagePattern::Any => true,
207 }
208 }
209
210 pub fn extract(&self, text: &str, query: &str) -> Option<Value> {
212 if let Ok(json) = serde_json::from_str::<Value>(text) {
213 if let Ok(selector) = jsonpath::Selector::new(query) {
214 let results: Vec<_> = selector.find(&json).collect();
215 results.first().cloned().cloned()
216 } else {
217 None
218 }
219 } else {
220 None
221 }
222 }
223}
224
/// Identifier of a single WebSocket connection (a plain `String`).
pub type ConnectionId = String;
227
/// Tracks which connections belong to which rooms and hands out per-room
/// broadcast channels.
///
/// All state lives behind `Arc`s, so cloning a `RoomManager` yields another
/// handle to the same underlying maps.
#[derive(Clone)]
pub struct RoomManager {
    /// Room name -> ids of the connections currently in that room.
    rooms: Arc<RwLock<HashMap<String, HashSet<ConnectionId>>>>,
    /// Connection id -> names of the rooms that connection has joined.
    connections: Arc<RwLock<HashMap<ConnectionId, HashSet<String>>>>,
    /// Room name -> broadcast sender, created lazily on first request.
    broadcasters: Arc<RwLock<HashMap<String, broadcast::Sender<String>>>>,
}
235
236impl RoomManager {
237 pub fn new() -> Self {
239 Self {
240 rooms: Arc::new(RwLock::new(HashMap::new())),
241 connections: Arc::new(RwLock::new(HashMap::new())),
242 broadcasters: Arc::new(RwLock::new(HashMap::new())),
243 }
244 }
245
246 pub async fn join(&self, conn_id: &str, room: &str) -> HandlerResult<()> {
248 let mut rooms = self.rooms.write().await;
249 let mut connections = self.connections.write().await;
250
251 rooms
252 .entry(room.to_string())
253 .or_insert_with(HashSet::new)
254 .insert(conn_id.to_string());
255
256 connections
257 .entry(conn_id.to_string())
258 .or_insert_with(HashSet::new)
259 .insert(room.to_string());
260
261 Ok(())
262 }
263
264 pub async fn leave(&self, conn_id: &str, room: &str) -> HandlerResult<()> {
266 let mut rooms = self.rooms.write().await;
267 let mut connections = self.connections.write().await;
268
269 if let Some(room_members) = rooms.get_mut(room) {
270 room_members.remove(conn_id);
271 if room_members.is_empty() {
272 rooms.remove(room);
273 }
274 }
275
276 if let Some(conn_rooms) = connections.get_mut(conn_id) {
277 conn_rooms.remove(room);
278 if conn_rooms.is_empty() {
279 connections.remove(conn_id);
280 }
281 }
282
283 Ok(())
284 }
285
286 pub async fn leave_all(&self, conn_id: &str) -> HandlerResult<()> {
288 let mut connections = self.connections.write().await;
289 if let Some(conn_rooms) = connections.remove(conn_id) {
290 let mut rooms = self.rooms.write().await;
291 for room in conn_rooms {
292 if let Some(room_members) = rooms.get_mut(&room) {
293 room_members.remove(conn_id);
294 if room_members.is_empty() {
295 rooms.remove(&room);
296 }
297 }
298 }
299 }
300 Ok(())
301 }
302
303 pub async fn get_room_members(&self, room: &str) -> Vec<ConnectionId> {
305 let rooms = self.rooms.read().await;
306 rooms
307 .get(room)
308 .map(|members| members.iter().cloned().collect())
309 .unwrap_or_default()
310 }
311
312 pub async fn get_connection_rooms(&self, conn_id: &str) -> Vec<String> {
314 let connections = self.connections.read().await;
315 connections
316 .get(conn_id)
317 .map(|rooms| rooms.iter().cloned().collect())
318 .unwrap_or_default()
319 }
320
321 pub async fn get_broadcaster(&self, room: &str) -> broadcast::Sender<String> {
323 let mut broadcasters = self.broadcasters.write().await;
324 broadcasters
325 .entry(room.to_string())
326 .or_insert_with(|| {
327 let (tx, _) = broadcast::channel(1024);
328 tx
329 })
330 .clone()
331 }
332}
333
334impl Default for RoomManager {
335 fn default() -> Self {
336 Self::new()
337 }
338}
339
/// Per-connection context passed to [`WsHandler`] callbacks.
///
/// Lets a handler send messages to its client, manage room membership, and
/// keep arbitrary JSON metadata for the connection.
pub struct WsContext {
    /// Identifier of this connection.
    pub connection_id: ConnectionId,
    /// Path associated with this connection.
    // NOTE(review): presumably the request path the client connected on;
    // confirm against the code that constructs the context.
    pub path: String,
    /// Shared room registry used for join/leave/broadcast operations.
    room_manager: RoomManager,
    /// Queue of outgoing messages for this connection's client.
    message_tx: tokio::sync::mpsc::UnboundedSender<Message>,
    /// Key/value metadata attached to this connection.
    metadata: Arc<RwLock<HashMap<String, Value>>>,
}
353
354impl WsContext {
355 pub fn new(
357 connection_id: ConnectionId,
358 path: String,
359 room_manager: RoomManager,
360 message_tx: tokio::sync::mpsc::UnboundedSender<Message>,
361 ) -> Self {
362 Self {
363 connection_id,
364 path,
365 room_manager,
366 message_tx,
367 metadata: Arc::new(RwLock::new(HashMap::new())),
368 }
369 }
370
371 pub async fn send_text(&self, text: &str) -> HandlerResult<()> {
373 self.message_tx
374 .send(Message::Text(text.to_string().into()))
375 .map_err(|e| HandlerError::SendError(e.to_string()))
376 }
377
378 pub async fn send_binary(&self, data: Vec<u8>) -> HandlerResult<()> {
380 self.message_tx
381 .send(Message::Binary(data.into()))
382 .map_err(|e| HandlerError::SendError(e.to_string()))
383 }
384
385 pub async fn send_json(&self, value: &Value) -> HandlerResult<()> {
387 let text = serde_json::to_string(value)?;
388 self.send_text(&text).await
389 }
390
391 pub async fn join_room(&self, room: &str) -> HandlerResult<()> {
393 self.room_manager.join(&self.connection_id, room).await
394 }
395
396 pub async fn leave_room(&self, room: &str) -> HandlerResult<()> {
398 self.room_manager.leave(&self.connection_id, room).await
399 }
400
401 pub async fn broadcast_to_room(&self, room: &str, text: &str) -> HandlerResult<()> {
403 let broadcaster = self.room_manager.get_broadcaster(room).await;
404 broadcaster
405 .send(text.to_string())
406 .map_err(|e| HandlerError::RoomError(e.to_string()))?;
407 Ok(())
408 }
409
410 pub async fn get_rooms(&self) -> Vec<String> {
412 self.room_manager.get_connection_rooms(&self.connection_id).await
413 }
414
415 pub async fn set_metadata(&self, key: &str, value: Value) {
417 let mut metadata = self.metadata.write().await;
418 metadata.insert(key.to_string(), value);
419 }
420
421 pub async fn get_metadata(&self, key: &str) -> Option<Value> {
423 let metadata = self.metadata.read().await;
424 metadata.get(key).cloned()
425 }
426}
427
/// Callbacks implemented by WebSocket handlers.
///
/// Only [`WsHandler::on_message`] is required; the other methods have
/// default implementations.
// NOTE(review): the code that drives these callbacks is not part of this
// file; the lifecycle descriptions below are inferred from the method names.
#[async_trait]
pub trait WsHandler: Send + Sync {
    /// Connection-established hook. The default implementation does nothing
    /// and returns `Ok(())`.
    async fn on_connect(&self, _ctx: &mut WsContext) -> HandlerResult<()> {
        Ok(())
    }

    /// Handles one incoming message for the connection described by `ctx`.
    async fn on_message(&self, ctx: &mut WsContext, msg: WsMessage) -> HandlerResult<()>;

    /// Connection-closed hook. The default implementation does nothing and
    /// returns `Ok(())`.
    async fn on_disconnect(&self, _ctx: &mut WsContext) -> HandlerResult<()> {
        Ok(())
    }

    /// Whether this handler applies to connections on `_path`; used by
    /// `HandlerRegistry` to select handlers. The default accepts every path.
    fn handles_path(&self, _path: &str) -> bool {
        true
    }
}
449
/// Ordered list of `(pattern, callback)` routes for text messages.
///
/// A callback receives the message text and may return a response string.
pub struct MessageRouter {
    /// Routes in registration order; earlier routes are tried first.
    routes: Vec<(MessagePattern, Box<dyn Fn(String) -> Option<String> + Send + Sync>)>,
}
454
455impl MessageRouter {
456 pub fn new() -> Self {
458 Self { routes: Vec::new() }
459 }
460
461 pub fn on<F>(&mut self, pattern: MessagePattern, handler: F) -> &mut Self
463 where
464 F: Fn(String) -> Option<String> + Send + Sync + 'static,
465 {
466 self.routes.push((pattern, Box::new(handler)));
467 self
468 }
469
470 pub fn route(&self, text: &str) -> Option<String> {
472 for (pattern, handler) in &self.routes {
473 if pattern.matches(text) {
474 if let Some(response) = handler(text.to_string()) {
475 return Some(response);
476 }
477 }
478 }
479 None
480 }
481}
482
483impl Default for MessageRouter {
484 fn default() -> Self {
485 Self::new()
486 }
487}
488
/// Collection of registered [`WsHandler`]s, queryable by connection path.
pub struct HandlerRegistry {
    /// Registered handlers in registration order.
    handlers: Vec<Arc<dyn WsHandler>>,
    /// Hot-reload flag; this type only stores and reports it.
    hot_reload_enabled: bool,
}
494
495impl HandlerRegistry {
496 pub fn new() -> Self {
498 Self {
499 handlers: Vec::new(),
500 hot_reload_enabled: std::env::var("MOCKFORGE_WS_HOTRELOAD")
501 .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
502 .unwrap_or(false),
503 }
504 }
505
506 pub fn with_hot_reload() -> Self {
508 Self {
509 handlers: Vec::new(),
510 hot_reload_enabled: true,
511 }
512 }
513
514 pub fn is_hot_reload_enabled(&self) -> bool {
516 self.hot_reload_enabled
517 }
518
519 pub fn register<H: WsHandler + 'static>(&mut self, handler: H) -> &mut Self {
521 self.handlers.push(Arc::new(handler));
522 self
523 }
524
525 pub fn get_handlers(&self, path: &str) -> Vec<Arc<dyn WsHandler>> {
527 self.handlers.iter().filter(|h| h.handles_path(path)).cloned().collect()
528 }
529
530 pub fn has_handler_for(&self, path: &str) -> bool {
532 self.handlers.iter().any(|h| h.handles_path(path))
533 }
534
535 pub fn clear(&mut self) {
537 self.handlers.clear();
538 }
539
540 pub fn len(&self) -> usize {
542 self.handlers.len()
543 }
544
545 pub fn is_empty(&self) -> bool {
547 self.handlers.is_empty()
548 }
549}
550
551impl Default for HandlerRegistry {
552 fn default() -> Self {
553 Self::new()
554 }
555}
556
/// Configuration for forwarding selected messages to an upstream WebSocket
/// server.
#[derive(Clone)]
pub struct PassthroughConfig {
    /// Pattern deciding which text messages are forwarded.
    pub pattern: MessagePattern,
    /// URL of the upstream WebSocket server.
    pub upstream_url: String,
}
565
566impl PassthroughConfig {
567 pub fn new(pattern: MessagePattern, upstream_url: String) -> Self {
569 Self {
570 pattern,
571 upstream_url,
572 }
573 }
574
575 pub fn regex(regex: &str, upstream_url: String) -> HandlerResult<Self> {
577 Ok(Self {
578 pattern: MessagePattern::regex(regex)?,
579 upstream_url,
580 })
581 }
582}
583
/// [`WsHandler`] that forwards messages to an upstream WebSocket server and
/// relays the upstream's messages back to the client.
// NOTE(review): the upstream writer is stored on the handler itself, so a
// handler instance shared by several client connections shares one upstream
// connection between them — confirm this is intended.
pub struct PassthroughHandler {
    /// Pattern and upstream URL.
    config: PassthroughConfig,
    /// Write half of the upstream connection; `None` while disconnected.
    upstream_tx: Mutex<Option<UpstreamSender>>,
}
590
/// Write half (sink) of a split tokio-tungstenite client WebSocket stream
/// over a plain or TLS TCP connection.
type UpstreamSender = futures_util::stream::SplitSink<
    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
    tokio_tungstenite::tungstenite::Message,
>;
595
596impl PassthroughHandler {
597 pub fn new(config: PassthroughConfig) -> Self {
599 Self {
600 config,
601 upstream_tx: Mutex::new(None),
602 }
603 }
604
605 pub fn should_passthrough(&self, text: &str) -> bool {
607 self.config.pattern.matches(text)
608 }
609
610 pub fn upstream_url(&self) -> &str {
612 &self.config.upstream_url
613 }
614
615 async fn ensure_connected(
618 &self,
619 client_tx: &tokio::sync::mpsc::UnboundedSender<Message>,
620 ) -> HandlerResult<()> {
621 let mut guard = self.upstream_tx.lock().await;
622 if guard.is_some() {
623 return Ok(());
624 }
625
626 let url = &self.config.upstream_url;
627 tracing::info!(upstream = %url, "Connecting to upstream WebSocket server");
628
629 let (ws_stream, _response) = tokio_tungstenite::connect_async(url)
630 .await
631 .map_err(|e| HandlerError::ConnectionError(format!("Upstream connect failed: {e}")))?;
632
633 let (write, mut read) = ws_stream.split();
634 *guard = Some(write);
635
636 let client_tx = client_tx.clone();
638 tokio::spawn(async move {
639 while let Some(Ok(msg)) = read.next().await {
640 let axum_msg = match msg {
641 tokio_tungstenite::tungstenite::Message::Text(t) => {
642 Message::Text(t.to_string().into())
643 }
644 tokio_tungstenite::tungstenite::Message::Binary(b) => {
645 Message::Binary(b.to_vec().into())
646 }
647 tokio_tungstenite::tungstenite::Message::Ping(p) => {
648 Message::Ping(p.to_vec().into())
649 }
650 tokio_tungstenite::tungstenite::Message::Pong(p) => {
651 Message::Pong(p.to_vec().into())
652 }
653 tokio_tungstenite::tungstenite::Message::Close(_) => {
654 break;
655 }
656 tokio_tungstenite::tungstenite::Message::Frame(_) => continue,
657 };
658 if client_tx.send(axum_msg).is_err() {
659 break;
660 }
661 }
662 tracing::debug!("Upstream reader task finished");
663 });
664
665 Ok(())
666 }
667}
668
669#[async_trait]
670impl WsHandler for PassthroughHandler {
671 async fn on_connect(&self, ctx: &mut WsContext) -> HandlerResult<()> {
672 self.ensure_connected(&ctx.message_tx).await
673 }
674
675 async fn on_message(&self, ctx: &mut WsContext, msg: WsMessage) -> HandlerResult<()> {
676 match &msg {
677 WsMessage::Text(text) if self.should_passthrough(text) => {
678 self.ensure_connected(&ctx.message_tx).await?;
679 let mut guard = self.upstream_tx.lock().await;
680 if let Some(ref mut writer) = *guard {
681 writer
682 .send(tokio_tungstenite::tungstenite::Message::Text(text.clone().into()))
683 .await
684 .map_err(|e| {
685 HandlerError::SendError(format!("Upstream send failed: {e}"))
686 })?;
687 }
688 }
689 WsMessage::Binary(data) => {
690 self.ensure_connected(&ctx.message_tx).await?;
691 let mut guard = self.upstream_tx.lock().await;
692 if let Some(ref mut writer) = *guard {
693 writer
694 .send(tokio_tungstenite::tungstenite::Message::Binary(data.clone().into()))
695 .await
696 .map_err(|e| {
697 HandlerError::SendError(format!("Upstream send failed: {e}"))
698 })?;
699 }
700 }
701 _ => {}
702 }
703 Ok(())
704 }
705
706 async fn on_disconnect(&self, _ctx: &mut WsContext) -> HandlerResult<()> {
707 let mut guard = self.upstream_tx.lock().await;
708 if let Some(mut writer) = guard.take() {
709 let _ = writer.send(tokio_tungstenite::tungstenite::Message::Close(None)).await;
710 }
711 Ok(())
712 }
713}
714
// Unit tests for the message types, patterns, room manager, router, handler
// registry, passthrough configuration and per-connection context. None of
// them opens a network connection.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- WsMessage <-> axum Message conversions ----

    #[test]
    fn test_ws_message_text_from_axum() {
        let axum_msg = Message::Text("hello".to_string().into());
        let ws_msg: WsMessage = axum_msg.into();
        match ws_msg {
            WsMessage::Text(text) => assert_eq!(text, "hello"),
            _ => panic!("Expected Text message"),
        }
    }

    #[test]
    fn test_ws_message_binary_from_axum() {
        let data = vec![1, 2, 3, 4];
        let axum_msg = Message::Binary(data.clone().into());
        let ws_msg: WsMessage = axum_msg.into();
        match ws_msg {
            WsMessage::Binary(bytes) => assert_eq!(bytes, data),
            _ => panic!("Expected Binary message"),
        }
    }

    #[test]
    fn test_ws_message_ping_from_axum() {
        let data = vec![1, 2];
        let axum_msg = Message::Ping(data.clone().into());
        let ws_msg: WsMessage = axum_msg.into();
        match ws_msg {
            WsMessage::Ping(bytes) => assert_eq!(bytes, data),
            _ => panic!("Expected Ping message"),
        }
    }

    #[test]
    fn test_ws_message_pong_from_axum() {
        let data = vec![3, 4];
        let axum_msg = Message::Pong(data.clone().into());
        let ws_msg: WsMessage = axum_msg.into();
        match ws_msg {
            WsMessage::Pong(bytes) => assert_eq!(bytes, data),
            _ => panic!("Expected Pong message"),
        }
    }

    #[test]
    fn test_ws_message_close_from_axum() {
        let axum_msg = Message::Close(None);
        let ws_msg: WsMessage = axum_msg.into();
        assert!(matches!(ws_msg, WsMessage::Close));
    }

    #[test]
    fn test_ws_message_text_to_axum() {
        let ws_msg = WsMessage::Text("hello".to_string());
        let axum_msg: Message = ws_msg.into();
        assert!(matches!(axum_msg, Message::Text(_)));
    }

    #[test]
    fn test_ws_message_binary_to_axum() {
        let ws_msg = WsMessage::Binary(vec![1, 2, 3]);
        let axum_msg: Message = ws_msg.into();
        assert!(matches!(axum_msg, Message::Binary(_)));
    }

    #[test]
    fn test_ws_message_close_to_axum() {
        let ws_msg = WsMessage::Close;
        let axum_msg: Message = ws_msg.into();
        assert!(matches!(axum_msg, Message::Close(_)));
    }

    // ---- MessagePattern ----

    #[test]
    fn test_message_pattern_regex() {
        let pattern = MessagePattern::regex(r"^hello").unwrap();
        assert!(pattern.matches("hello world"));
        assert!(!pattern.matches("goodbye world"));
    }

    #[test]
    fn test_message_pattern_regex_invalid() {
        let result = MessagePattern::regex(r"[invalid");
        assert!(result.is_err());
    }

    #[test]
    fn test_message_pattern_exact() {
        let pattern = MessagePattern::exact("hello");
        assert!(pattern.matches("hello"));
        assert!(!pattern.matches("hello world"));
    }

    #[test]
    fn test_message_pattern_jsonpath() {
        let pattern = MessagePattern::jsonpath("$.type");
        assert!(pattern.matches(r#"{"type": "message"}"#));
        assert!(!pattern.matches(r#"{"name": "test"}"#));
    }

    #[test]
    fn test_message_pattern_jsonpath_nested() {
        let pattern = MessagePattern::jsonpath("$.user.name");
        assert!(pattern.matches(r#"{"user": {"name": "John"}}"#));
        assert!(!pattern.matches(r#"{"user": {"email": "john@example.com"}}"#));
    }

    #[test]
    fn test_message_pattern_jsonpath_invalid_json() {
        let pattern = MessagePattern::jsonpath("$.type");
        assert!(!pattern.matches("not json"));
    }

    #[test]
    fn test_message_pattern_any() {
        let pattern = MessagePattern::any();
        assert!(pattern.matches("anything"));
        assert!(pattern.matches(""));
        assert!(pattern.matches(r#"{"json": true}"#));
    }

    #[test]
    fn test_message_pattern_extract() {
        let pattern = MessagePattern::jsonpath("$.type");
        let result = pattern.extract(r#"{"type": "greeting", "data": "hello"}"#, "$.type");
        assert_eq!(result, Some(serde_json::json!("greeting")));
    }

    #[test]
    fn test_message_pattern_extract_nested() {
        let pattern = MessagePattern::any();
        let result = pattern.extract(r#"{"user": {"id": 123}}"#, "$.user.id");
        assert_eq!(result, Some(serde_json::json!(123)));
    }

    #[test]
    fn test_message_pattern_extract_not_found() {
        let pattern = MessagePattern::any();
        let result = pattern.extract(r#"{"type": "message"}"#, "$.nonexistent");
        assert!(result.is_none());
    }

    #[test]
    fn test_message_pattern_extract_invalid_json() {
        let pattern = MessagePattern::any();
        let result = pattern.extract("not json", "$.type");
        assert!(result.is_none());
    }

    // ---- RoomManager ----

    #[tokio::test]
    async fn test_room_manager() {
        let manager = RoomManager::new();

        // Two connections, with conn1 in two rooms.
        manager.join("conn1", "room1").await.unwrap();
        manager.join("conn1", "room2").await.unwrap();
        manager.join("conn2", "room1").await.unwrap();

        // Membership is visible from the room side...
        let room1_members = manager.get_room_members("room1").await;
        assert_eq!(room1_members.len(), 2);
        assert!(room1_members.contains(&"conn1".to_string()));
        assert!(room1_members.contains(&"conn2".to_string()));

        // ...and from the connection side.
        let conn1_rooms = manager.get_connection_rooms("conn1").await;
        assert_eq!(conn1_rooms.len(), 2);
        assert!(conn1_rooms.contains(&"room1".to_string()));
        assert!(conn1_rooms.contains(&"room2".to_string()));

        // Leaving one room keeps the other members.
        manager.leave("conn1", "room1").await.unwrap();
        let room1_members = manager.get_room_members("room1").await;
        assert_eq!(room1_members.len(), 1);
        assert!(room1_members.contains(&"conn2".to_string()));

        // Leaving everything forgets the connection.
        manager.leave_all("conn1").await.unwrap();
        let conn1_rooms = manager.get_connection_rooms("conn1").await;
        assert_eq!(conn1_rooms.len(), 0);
    }

    #[tokio::test]
    async fn test_room_manager_default() {
        let manager = RoomManager::default();
        manager.join("conn1", "room1").await.unwrap();
        let members = manager.get_room_members("room1").await;
        assert_eq!(members.len(), 1);
    }

    #[tokio::test]
    async fn test_room_manager_empty_room() {
        let manager = RoomManager::new();
        let members = manager.get_room_members("nonexistent").await;
        assert!(members.is_empty());
    }

    #[tokio::test]
    async fn test_room_manager_empty_connection() {
        let manager = RoomManager::new();
        let rooms = manager.get_connection_rooms("nonexistent").await;
        assert!(rooms.is_empty());
    }

    #[tokio::test]
    async fn test_room_manager_leave_nonexistent() {
        let manager = RoomManager::new();
        let result = manager.leave("conn1", "room1").await;
        // Leaving a room that was never joined is not an error.
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_room_manager_broadcaster() {
        let manager = RoomManager::new();
        manager.join("conn1", "room1").await.unwrap();

        let broadcaster = manager.get_broadcaster("room1").await;
        let mut receiver = broadcaster.subscribe();

        // A receiver subscribed before the send gets the message.
        broadcaster.send("hello".to_string()).unwrap();

        let msg = receiver.recv().await.unwrap();
        assert_eq!(msg, "hello");
    }

    #[tokio::test]
    async fn test_room_manager_room_cleanup_on_last_leave() {
        let manager = RoomManager::new();
        manager.join("conn1", "room1").await.unwrap();
        manager.leave("conn1", "room1").await.unwrap();

        let members = manager.get_room_members("room1").await;
        assert!(members.is_empty());
    }

    // ---- MessageRouter ----

    #[test]
    fn test_message_router() {
        let mut router = MessageRouter::new();

        router
            .on(MessagePattern::exact("ping"), |_| Some("pong".to_string()))
            .on(MessagePattern::regex(r"^hello").unwrap(), |_| Some("hi there!".to_string()));

        assert_eq!(router.route("ping"), Some("pong".to_string()));
        assert_eq!(router.route("hello world"), Some("hi there!".to_string()));
        assert_eq!(router.route("goodbye"), None);
    }

    #[test]
    fn test_message_router_default() {
        let router = MessageRouter::default();
        assert_eq!(router.route("anything"), None);
    }

    #[test]
    fn test_message_router_first_match_wins() {
        let mut router = MessageRouter::new();
        router
            .on(MessagePattern::any(), |_| Some("first".to_string()))
            .on(MessagePattern::any(), |_| Some("second".to_string()));

        assert_eq!(router.route("test"), Some("first".to_string()));
    }

    #[test]
    fn test_message_router_handler_returns_none() {
        let mut router = MessageRouter::new();
        router
            .on(MessagePattern::exact("skip"), |_| None)
            .on(MessagePattern::any(), |_| Some("fallback".to_string()));

        // A matching route that returns `None` falls through to later routes.
        assert_eq!(router.route("skip"), Some("fallback".to_string()));
    }

    // ---- HandlerRegistry ----

    // Handler that accepts every path (default `handles_path`).
    struct TestHandler;

    #[async_trait]
    impl WsHandler for TestHandler {
        async fn on_message(&self, _ctx: &mut WsContext, _msg: WsMessage) -> HandlerResult<()> {
            Ok(())
        }
    }

    // Handler that accepts exactly one path.
    struct PathSpecificHandler {
        path: String,
    }

    #[async_trait]
    impl WsHandler for PathSpecificHandler {
        async fn on_message(&self, _ctx: &mut WsContext, _msg: WsMessage) -> HandlerResult<()> {
            Ok(())
        }

        fn handles_path(&self, path: &str) -> bool {
            path == self.path
        }
    }

    #[test]
    fn test_handler_registry_new() {
        let registry = HandlerRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn test_handler_registry_default() {
        let registry = HandlerRegistry::default();
        assert!(registry.is_empty());
    }

    #[test]
    fn test_handler_registry_register() {
        let mut registry = HandlerRegistry::new();
        registry.register(TestHandler);
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_handler_registry_get_handlers() {
        let mut registry = HandlerRegistry::new();
        registry.register(TestHandler);

        let handlers = registry.get_handlers("/any/path");
        assert_eq!(handlers.len(), 1);
    }

    #[test]
    fn test_handler_registry_path_filtering() {
        let mut registry = HandlerRegistry::new();
        registry.register(PathSpecificHandler {
            path: "/ws/chat".to_string(),
        });
        registry.register(PathSpecificHandler {
            path: "/ws/events".to_string(),
        });

        let chat_handlers = registry.get_handlers("/ws/chat");
        assert_eq!(chat_handlers.len(), 1);

        let events_handlers = registry.get_handlers("/ws/events");
        assert_eq!(events_handlers.len(), 1);

        let other_handlers = registry.get_handlers("/ws/other");
        assert!(other_handlers.is_empty());
    }

    #[test]
    fn test_handler_registry_has_handler_for() {
        let mut registry = HandlerRegistry::new();
        registry.register(PathSpecificHandler {
            path: "/ws/chat".to_string(),
        });

        assert!(registry.has_handler_for("/ws/chat"));
        assert!(!registry.has_handler_for("/ws/other"));
    }

    #[test]
    fn test_handler_registry_clear() {
        let mut registry = HandlerRegistry::new();
        registry.register(TestHandler);
        registry.register(TestHandler);
        assert_eq!(registry.len(), 2);

        registry.clear();
        assert!(registry.is_empty());
    }

    #[test]
    fn test_handler_registry_with_hot_reload() {
        let registry = HandlerRegistry::with_hot_reload();
        assert!(registry.is_hot_reload_enabled());
    }

    // ---- PassthroughConfig ----

    #[test]
    fn test_passthrough_config_new() {
        let config =
            PassthroughConfig::new(MessagePattern::any(), "ws://upstream:8080".to_string());
        assert_eq!(config.upstream_url, "ws://upstream:8080");
    }

    #[test]
    fn test_passthrough_config_regex() {
        let config =
            PassthroughConfig::regex(r"^forward", "ws://upstream:8080".to_string()).unwrap();
        assert!(config.pattern.matches("forward this"));
        assert!(!config.pattern.matches("don't forward"));
    }

    #[test]
    fn test_passthrough_config_regex_invalid() {
        let result = PassthroughConfig::regex(r"[invalid", "ws://upstream:8080".to_string());
        assert!(result.is_err());
    }

    // ---- PassthroughHandler (no network access) ----

    #[test]
    fn test_passthrough_handler_should_passthrough() {
        let config =
            PassthroughConfig::regex(r"^proxy:", "ws://upstream:8080".to_string()).unwrap();
        let handler = PassthroughHandler::new(config);

        assert!(handler.should_passthrough("proxy:hello"));
        assert!(!handler.should_passthrough("hello"));
    }

    #[test]
    fn test_passthrough_handler_upstream_url() {
        let config =
            PassthroughConfig::new(MessagePattern::any(), "ws://upstream:8080".to_string());
        let handler = PassthroughHandler::new(config);

        assert_eq!(handler.upstream_url(), "ws://upstream:8080");
    }

    // ---- HandlerError display messages ----

    #[test]
    fn test_handler_error_send_error() {
        let err = HandlerError::SendError("connection closed".to_string());
        assert!(err.to_string().contains("send message"));
        assert!(err.to_string().contains("connection closed"));
    }

    #[test]
    fn test_handler_error_json_error() {
        let json_err = serde_json::from_str::<serde_json::Value>("invalid").unwrap_err();
        let err = HandlerError::JsonError(json_err);
        assert!(err.to_string().contains("JSON"));
    }

    #[test]
    fn test_handler_error_pattern_error() {
        let err = HandlerError::PatternError("invalid regex".to_string());
        assert!(err.to_string().contains("Pattern"));
    }

    #[test]
    fn test_handler_error_room_error() {
        let err = HandlerError::RoomError("room full".to_string());
        assert!(err.to_string().contains("Room"));
    }

    #[test]
    fn test_handler_error_connection_error() {
        let err = HandlerError::ConnectionError("timeout".to_string());
        assert!(err.to_string().contains("Connection"));
    }

    #[test]
    fn test_handler_error_generic() {
        let err = HandlerError::Generic("something went wrong".to_string());
        assert!(err.to_string().contains("something went wrong"));
    }

    // ---- WsContext ----

    #[tokio::test]
    async fn test_ws_context_metadata() {
        let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
        let ctx = WsContext::new("conn123".to_string(), "/ws".to_string(), RoomManager::new(), tx);

        // A stored value can be read back.
        ctx.set_metadata("user", serde_json::json!({"id": 1})).await;
        let value = ctx.get_metadata("user").await;
        assert_eq!(value, Some(serde_json::json!({"id": 1})));

        // An unknown key yields `None`.
        let missing = ctx.get_metadata("nonexistent").await;
        assert!(missing.is_none());
    }

    #[tokio::test]
    async fn test_ws_context_send_text() {
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let ctx = WsContext::new("conn123".to_string(), "/ws".to_string(), RoomManager::new(), tx);

        ctx.send_text("hello").await.unwrap();

        let msg = rx.recv().await.unwrap();
        assert!(matches!(msg, Message::Text(_)));
    }

    #[tokio::test]
    async fn test_ws_context_send_binary() {
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let ctx = WsContext::new("conn123".to_string(), "/ws".to_string(), RoomManager::new(), tx);

        ctx.send_binary(vec![1, 2, 3]).await.unwrap();

        let msg = rx.recv().await.unwrap();
        assert!(matches!(msg, Message::Binary(_)));
    }

    #[tokio::test]
    async fn test_ws_context_send_json() {
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let ctx = WsContext::new("conn123".to_string(), "/ws".to_string(), RoomManager::new(), tx);

        ctx.send_json(&serde_json::json!({"type": "test"})).await.unwrap();

        let msg = rx.recv().await.unwrap();
        assert!(matches!(msg, Message::Text(_)));
    }

    #[tokio::test]
    async fn test_ws_context_rooms() {
        let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
        let ctx = WsContext::new("conn123".to_string(), "/ws".to_string(), RoomManager::new(), tx);

        // Join two rooms.
        ctx.join_room("chat").await.unwrap();
        ctx.join_room("notifications").await.unwrap();

        let rooms = ctx.get_rooms().await;
        assert_eq!(rooms.len(), 2);

        // Leaving one room leaves the other membership intact.
        ctx.leave_room("chat").await.unwrap();
        let rooms = ctx.get_rooms().await;
        assert_eq!(rooms.len(), 1);
    }
}