1use async_trait::async_trait;
73use axum::extract::ws::Message;
74use regex::Regex;
75use serde_json::Value;
76use std::collections::{HashMap, HashSet};
77use std::sync::Arc;
78use tokio::sync::{broadcast, RwLock};
79
80pub type HandlerResult<T> = Result<T, HandlerError>;
82
83#[derive(Debug, thiserror::Error)]
85pub enum HandlerError {
86 #[error("Failed to send message: {0}")]
87 SendError(String),
88
89 #[error("Failed to parse JSON: {0}")]
90 JsonError(#[from] serde_json::Error),
91
92 #[error("Pattern matching error: {0}")]
93 PatternError(String),
94
95 #[error("Room operation failed: {0}")]
96 RoomError(String),
97
98 #[error("Connection error: {0}")]
99 ConnectionError(String),
100
101 #[error("Handler error: {0}")]
102 Generic(String),
103}
104
105#[derive(Debug, Clone)]
107pub enum WsMessage {
108 Text(String),
109 Binary(Vec<u8>),
110 Ping(Vec<u8>),
111 Pong(Vec<u8>),
112 Close,
113}
114
115impl From<Message> for WsMessage {
116 fn from(msg: Message) -> Self {
117 match msg {
118 Message::Text(text) => WsMessage::Text(text.to_string()),
119 Message::Binary(data) => WsMessage::Binary(data.to_vec()),
120 Message::Ping(data) => WsMessage::Ping(data.to_vec()),
121 Message::Pong(data) => WsMessage::Pong(data.to_vec()),
122 Message::Close(_) => WsMessage::Close,
123 }
124 }
125}
126
127impl From<WsMessage> for Message {
128 fn from(msg: WsMessage) -> Self {
129 match msg {
130 WsMessage::Text(text) => Message::Text(text.into()),
131 WsMessage::Binary(data) => Message::Binary(data.into()),
132 WsMessage::Ping(data) => Message::Ping(data.into()),
133 WsMessage::Pong(data) => Message::Pong(data.into()),
134 WsMessage::Close => Message::Close(None),
135 }
136 }
137}
138
139#[derive(Debug, Clone)]
141pub enum MessagePattern {
142 Regex(Regex),
144 JsonPath(String),
146 Exact(String),
148 Any,
150}
151
152impl MessagePattern {
153 pub fn regex(pattern: &str) -> HandlerResult<Self> {
155 Ok(MessagePattern::Regex(
156 Regex::new(pattern).map_err(|e| HandlerError::PatternError(e.to_string()))?,
157 ))
158 }
159
160 pub fn jsonpath(query: &str) -> Self {
162 MessagePattern::JsonPath(query.to_string())
163 }
164
165 pub fn exact(text: &str) -> Self {
167 MessagePattern::Exact(text.to_string())
168 }
169
170 pub fn any() -> Self {
172 MessagePattern::Any
173 }
174
175 pub fn matches(&self, text: &str) -> bool {
177 match self {
178 MessagePattern::Regex(re) => re.is_match(text),
179 MessagePattern::JsonPath(query) => {
180 if let Ok(json) = serde_json::from_str::<Value>(text) {
182 if let Ok(selector) = jsonpath::Selector::new(query) {
184 let results: Vec<_> = selector.find(&json).collect();
185 !results.is_empty()
186 } else {
187 false
188 }
189 } else {
190 false
191 }
192 }
193 MessagePattern::Exact(expected) => text == expected,
194 MessagePattern::Any => true,
195 }
196 }
197
198 pub fn extract(&self, text: &str, query: &str) -> Option<Value> {
200 if let Ok(json) = serde_json::from_str::<Value>(text) {
201 if let Ok(selector) = jsonpath::Selector::new(query) {
202 let results: Vec<_> = selector.find(&json).collect();
203 results.first().cloned().cloned()
204 } else {
205 None
206 }
207 } else {
208 None
209 }
210 }
211}
212
213pub type ConnectionId = String;
215
216#[derive(Clone)]
218pub struct RoomManager {
219 rooms: Arc<RwLock<HashMap<String, HashSet<ConnectionId>>>>,
220 connections: Arc<RwLock<HashMap<ConnectionId, HashSet<String>>>>,
221 broadcasters: Arc<RwLock<HashMap<String, broadcast::Sender<String>>>>,
222}
223
224impl RoomManager {
225 pub fn new() -> Self {
227 Self {
228 rooms: Arc::new(RwLock::new(HashMap::new())),
229 connections: Arc::new(RwLock::new(HashMap::new())),
230 broadcasters: Arc::new(RwLock::new(HashMap::new())),
231 }
232 }
233
234 pub async fn join(&self, conn_id: &str, room: &str) -> HandlerResult<()> {
236 let mut rooms = self.rooms.write().await;
237 let mut connections = self.connections.write().await;
238
239 rooms
240 .entry(room.to_string())
241 .or_insert_with(HashSet::new)
242 .insert(conn_id.to_string());
243
244 connections
245 .entry(conn_id.to_string())
246 .or_insert_with(HashSet::new)
247 .insert(room.to_string());
248
249 Ok(())
250 }
251
252 pub async fn leave(&self, conn_id: &str, room: &str) -> HandlerResult<()> {
254 let mut rooms = self.rooms.write().await;
255 let mut connections = self.connections.write().await;
256
257 if let Some(room_members) = rooms.get_mut(room) {
258 room_members.remove(conn_id);
259 if room_members.is_empty() {
260 rooms.remove(room);
261 }
262 }
263
264 if let Some(conn_rooms) = connections.get_mut(conn_id) {
265 conn_rooms.remove(room);
266 if conn_rooms.is_empty() {
267 connections.remove(conn_id);
268 }
269 }
270
271 Ok(())
272 }
273
274 pub async fn leave_all(&self, conn_id: &str) -> HandlerResult<()> {
276 let mut connections = self.connections.write().await;
277 if let Some(conn_rooms) = connections.remove(conn_id) {
278 let mut rooms = self.rooms.write().await;
279 for room in conn_rooms {
280 if let Some(room_members) = rooms.get_mut(&room) {
281 room_members.remove(conn_id);
282 if room_members.is_empty() {
283 rooms.remove(&room);
284 }
285 }
286 }
287 }
288 Ok(())
289 }
290
291 pub async fn get_room_members(&self, room: &str) -> Vec<ConnectionId> {
293 let rooms = self.rooms.read().await;
294 rooms
295 .get(room)
296 .map(|members| members.iter().cloned().collect())
297 .unwrap_or_default()
298 }
299
300 pub async fn get_connection_rooms(&self, conn_id: &str) -> Vec<String> {
302 let connections = self.connections.read().await;
303 connections
304 .get(conn_id)
305 .map(|rooms| rooms.iter().cloned().collect())
306 .unwrap_or_default()
307 }
308
309 pub async fn get_broadcaster(&self, room: &str) -> broadcast::Sender<String> {
311 let mut broadcasters = self.broadcasters.write().await;
312 broadcasters
313 .entry(room.to_string())
314 .or_insert_with(|| {
315 let (tx, _) = broadcast::channel(1024);
316 tx
317 })
318 .clone()
319 }
320}
321
322impl Default for RoomManager {
323 fn default() -> Self {
324 Self::new()
325 }
326}
327
328pub struct WsContext {
330 pub connection_id: ConnectionId,
332 pub path: String,
334 room_manager: RoomManager,
336 message_tx: tokio::sync::mpsc::UnboundedSender<Message>,
338 metadata: Arc<RwLock<HashMap<String, Value>>>,
340}
341
342impl WsContext {
343 pub fn new(
345 connection_id: ConnectionId,
346 path: String,
347 room_manager: RoomManager,
348 message_tx: tokio::sync::mpsc::UnboundedSender<Message>,
349 ) -> Self {
350 Self {
351 connection_id,
352 path,
353 room_manager,
354 message_tx,
355 metadata: Arc::new(RwLock::new(HashMap::new())),
356 }
357 }
358
359 pub async fn send_text(&self, text: &str) -> HandlerResult<()> {
361 self.message_tx
362 .send(Message::Text(text.to_string().into()))
363 .map_err(|e| HandlerError::SendError(e.to_string()))
364 }
365
366 pub async fn send_binary(&self, data: Vec<u8>) -> HandlerResult<()> {
368 self.message_tx
369 .send(Message::Binary(data.into()))
370 .map_err(|e| HandlerError::SendError(e.to_string()))
371 }
372
373 pub async fn send_json(&self, value: &Value) -> HandlerResult<()> {
375 let text = serde_json::to_string(value)?;
376 self.send_text(&text).await
377 }
378
379 pub async fn join_room(&self, room: &str) -> HandlerResult<()> {
381 self.room_manager.join(&self.connection_id, room).await
382 }
383
384 pub async fn leave_room(&self, room: &str) -> HandlerResult<()> {
386 self.room_manager.leave(&self.connection_id, room).await
387 }
388
389 pub async fn broadcast_to_room(&self, room: &str, text: &str) -> HandlerResult<()> {
391 let broadcaster = self.room_manager.get_broadcaster(room).await;
392 broadcaster
393 .send(text.to_string())
394 .map_err(|e| HandlerError::RoomError(e.to_string()))?;
395 Ok(())
396 }
397
398 pub async fn get_rooms(&self) -> Vec<String> {
400 self.room_manager.get_connection_rooms(&self.connection_id).await
401 }
402
403 pub async fn set_metadata(&self, key: &str, value: Value) {
405 let mut metadata = self.metadata.write().await;
406 metadata.insert(key.to_string(), value);
407 }
408
409 pub async fn get_metadata(&self, key: &str) -> Option<Value> {
411 let metadata = self.metadata.read().await;
412 metadata.get(key).cloned()
413 }
414}
415
416#[async_trait]
418pub trait WsHandler: Send + Sync {
419 async fn on_connect(&self, _ctx: &mut WsContext) -> HandlerResult<()> {
421 Ok(())
422 }
423
424 async fn on_message(&self, ctx: &mut WsContext, msg: WsMessage) -> HandlerResult<()>;
426
427 async fn on_disconnect(&self, _ctx: &mut WsContext) -> HandlerResult<()> {
429 Ok(())
430 }
431
432 fn handles_path(&self, _path: &str) -> bool {
434 true }
436}
437
438pub struct MessageRouter {
440 routes: Vec<(MessagePattern, Box<dyn Fn(String) -> Option<String> + Send + Sync>)>,
441}
442
443impl MessageRouter {
444 pub fn new() -> Self {
446 Self { routes: Vec::new() }
447 }
448
449 pub fn on<F>(&mut self, pattern: MessagePattern, handler: F) -> &mut Self
451 where
452 F: Fn(String) -> Option<String> + Send + Sync + 'static,
453 {
454 self.routes.push((pattern, Box::new(handler)));
455 self
456 }
457
458 pub fn route(&self, text: &str) -> Option<String> {
460 for (pattern, handler) in &self.routes {
461 if pattern.matches(text) {
462 if let Some(response) = handler(text.to_string()) {
463 return Some(response);
464 }
465 }
466 }
467 None
468 }
469}
470
471impl Default for MessageRouter {
472 fn default() -> Self {
473 Self::new()
474 }
475}
476
477pub struct HandlerRegistry {
479 handlers: Vec<Arc<dyn WsHandler>>,
480 hot_reload_enabled: bool,
481}
482
483impl HandlerRegistry {
484 pub fn new() -> Self {
486 Self {
487 handlers: Vec::new(),
488 hot_reload_enabled: std::env::var("MOCKFORGE_WS_HOTRELOAD")
489 .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
490 .unwrap_or(false),
491 }
492 }
493
494 pub fn with_hot_reload() -> Self {
496 Self {
497 handlers: Vec::new(),
498 hot_reload_enabled: true,
499 }
500 }
501
502 pub fn is_hot_reload_enabled(&self) -> bool {
504 self.hot_reload_enabled
505 }
506
507 pub fn register<H: WsHandler + 'static>(&mut self, handler: H) -> &mut Self {
509 self.handlers.push(Arc::new(handler));
510 self
511 }
512
513 pub fn get_handlers(&self, path: &str) -> Vec<Arc<dyn WsHandler>> {
515 self.handlers.iter().filter(|h| h.handles_path(path)).cloned().collect()
516 }
517
518 pub fn has_handler_for(&self, path: &str) -> bool {
520 self.handlers.iter().any(|h| h.handles_path(path))
521 }
522
523 pub fn clear(&mut self) {
525 self.handlers.clear();
526 }
527
528 pub fn len(&self) -> usize {
530 self.handlers.len()
531 }
532
533 pub fn is_empty(&self) -> bool {
535 self.handlers.is_empty()
536 }
537}
538
539impl Default for HandlerRegistry {
540 fn default() -> Self {
541 Self::new()
542 }
543}
544
545#[derive(Clone)]
547pub struct PassthroughConfig {
548 pub pattern: MessagePattern,
550 pub upstream_url: String,
552}
553
554impl PassthroughConfig {
555 pub fn new(pattern: MessagePattern, upstream_url: String) -> Self {
557 Self {
558 pattern,
559 upstream_url,
560 }
561 }
562
563 pub fn regex(regex: &str, upstream_url: String) -> HandlerResult<Self> {
565 Ok(Self {
566 pattern: MessagePattern::regex(regex)?,
567 upstream_url,
568 })
569 }
570}
571
572pub struct PassthroughHandler {
574 config: PassthroughConfig,
575}
576
577impl PassthroughHandler {
578 pub fn new(config: PassthroughConfig) -> Self {
580 Self { config }
581 }
582
583 pub fn should_passthrough(&self, text: &str) -> bool {
585 self.config.pattern.matches(text)
586 }
587
588 pub fn upstream_url(&self) -> &str {
590 &self.config.upstream_url
591 }
592}
593
594#[async_trait]
595impl WsHandler for PassthroughHandler {
596 async fn on_message(&self, ctx: &mut WsContext, msg: WsMessage) -> HandlerResult<()> {
597 if let WsMessage::Text(text) = &msg {
598 if self.should_passthrough(text) {
599 ctx.send_text(&format!("PASSTHROUGH({}): {}", self.config.upstream_url, text))
602 .await?;
603 return Ok(());
604 }
605 }
606 Ok(())
607 }
608}
609
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal handler for registry tests; it only claims paths starting with `/ws`.
    struct PathScopedHandler;

    #[async_trait::async_trait]
    impl WsHandler for PathScopedHandler {
        async fn on_message(&self, _ctx: &mut WsContext, _msg: WsMessage) -> HandlerResult<()> {
            Ok(())
        }

        fn handles_path(&self, path: &str) -> bool {
            path.starts_with("/ws")
        }
    }

    /// Builds a context on a fresh channel and returns the receiving end as well.
    fn test_context() -> (WsContext, tokio::sync::mpsc::UnboundedReceiver<Message>) {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let ctx = WsContext::new("conn1".to_string(), "/ws".to_string(), RoomManager::new(), tx);
        (ctx, rx)
    }

    #[test]
    fn test_message_pattern_regex() {
        let pattern = MessagePattern::regex(r"^hello").unwrap();
        assert!(pattern.matches("hello world"));
        assert!(!pattern.matches("goodbye world"));
    }

    #[test]
    fn test_message_pattern_regex_invalid() {
        assert!(matches!(
            MessagePattern::regex("("),
            Err(HandlerError::PatternError(_))
        ));
    }

    #[test]
    fn test_message_pattern_exact() {
        let pattern = MessagePattern::exact("hello");
        assert!(pattern.matches("hello"));
        assert!(!pattern.matches("hello world"));
    }

    #[test]
    fn test_message_pattern_any() {
        let pattern = MessagePattern::any();
        assert!(pattern.matches(""));
        assert!(pattern.matches("anything at all"));
    }

    #[test]
    fn test_message_pattern_jsonpath() {
        let pattern = MessagePattern::jsonpath("$.type");
        assert!(pattern.matches(r#"{"type": "message"}"#));
        assert!(!pattern.matches(r#"{"name": "test"}"#));
        // Non-JSON input is "no match", not an error.
        assert!(!pattern.matches("not json"));
    }

    #[test]
    fn test_message_pattern_extract() {
        let pattern = MessagePattern::any();
        assert_eq!(
            pattern.extract(r#"{"type": "message"}"#, "$.type"),
            Some(Value::from("message"))
        );
        assert_eq!(pattern.extract(r#"{"name": "test"}"#, "$.type"), None);
        assert_eq!(pattern.extract("not json", "$.type"), None);
    }

    #[test]
    fn test_ws_message_text_round_trip() {
        let message: Message = WsMessage::Text("hello".to_string()).into();
        match WsMessage::from(message) {
            WsMessage::Text(text) => assert_eq!(text, "hello"),
            other => panic!("expected a text message, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_room_manager() {
        let manager = RoomManager::new();

        manager.join("conn1", "room1").await.unwrap();
        manager.join("conn1", "room2").await.unwrap();
        manager.join("conn2", "room1").await.unwrap();

        let room1_members = manager.get_room_members("room1").await;
        assert_eq!(room1_members.len(), 2);
        assert!(room1_members.contains(&"conn1".to_string()));
        assert!(room1_members.contains(&"conn2".to_string()));

        let conn1_rooms = manager.get_connection_rooms("conn1").await;
        assert_eq!(conn1_rooms.len(), 2);
        assert!(conn1_rooms.contains(&"room1".to_string()));
        assert!(conn1_rooms.contains(&"room2".to_string()));

        manager.leave("conn1", "room1").await.unwrap();
        let room1_members = manager.get_room_members("room1").await;
        assert_eq!(room1_members.len(), 1);
        assert!(room1_members.contains(&"conn2".to_string()));

        manager.leave_all("conn1").await.unwrap();
        let conn1_rooms = manager.get_connection_rooms("conn1").await;
        assert_eq!(conn1_rooms.len(), 0);
        // leave_all must also clean up the room side: room2 only contained conn1.
        assert!(manager.get_room_members("room2").await.is_empty());
        assert_eq!(manager.get_room_members("room1").await, vec!["conn2".to_string()]);
    }

    #[tokio::test]
    async fn test_room_manager_leave_unknown_is_noop() {
        let manager = RoomManager::new();
        manager.leave("ghost", "nowhere").await.unwrap();
        manager.leave_all("ghost").await.unwrap();
        assert!(manager.get_room_members("nowhere").await.is_empty());
        assert!(manager.get_connection_rooms("ghost").await.is_empty());
    }

    #[tokio::test]
    async fn test_ws_context_send_rooms_and_metadata() {
        let (ctx, mut rx) = test_context();

        ctx.send_json(&Value::from(42)).await.unwrap();
        let sent = WsMessage::from(rx.try_recv().unwrap());
        assert!(matches!(sent, WsMessage::Text(t) if t == "42"));

        assert_eq!(ctx.get_metadata("user").await, None);
        ctx.set_metadata("user", Value::from("alice")).await;
        assert_eq!(ctx.get_metadata("user").await, Some(Value::from("alice")));

        ctx.join_room("lobby").await.unwrap();
        assert_eq!(ctx.get_rooms().await, vec!["lobby".to_string()]);
        ctx.leave_room("lobby").await.unwrap();
        assert!(ctx.get_rooms().await.is_empty());

        // Once the receiving end is gone, sends fail instead of silently succeeding.
        drop(rx);
        assert!(matches!(
            ctx.send_text("late").await,
            Err(HandlerError::SendError(_))
        ));
    }

    #[test]
    fn test_message_router() {
        let mut router = MessageRouter::new();

        router
            .on(MessagePattern::exact("ping"), |_| Some("pong".to_string()))
            .on(MessagePattern::regex(r"^hello").unwrap(), |_| Some("hi there!".to_string()));

        assert_eq!(router.route("ping"), Some("pong".to_string()));
        assert_eq!(router.route("hello world"), Some("hi there!".to_string()));
        assert_eq!(router.route("goodbye"), None);
    }

    #[test]
    fn test_message_router_falls_through_on_none() {
        let mut router = MessageRouter::default();

        router
            .on(MessagePattern::any(), |_| None)
            .on(MessagePattern::exact("ping"), |_| Some("pong".to_string()));

        assert_eq!(router.route("ping"), Some("pong".to_string()));
        assert_eq!(router.route("other"), None);
    }

    #[test]
    fn test_handler_registry() {
        let mut registry = HandlerRegistry::with_hot_reload();
        assert!(registry.is_hot_reload_enabled());
        assert!(registry.is_empty());

        registry.register(PathScopedHandler);
        assert_eq!(registry.len(), 1);
        assert!(registry.has_handler_for("/ws/chat"));
        assert!(!registry.has_handler_for("/api"));
        assert_eq!(registry.get_handlers("/ws/chat").len(), 1);
        assert!(registry.get_handlers("/api").is_empty());

        registry.clear();
        assert!(registry.is_empty());
    }

    #[tokio::test]
    async fn test_passthrough_handler() {
        let config = PassthroughConfig::regex(r"^fwd", "ws://upstream".to_string()).unwrap();
        let handler = PassthroughHandler::new(config);
        assert!(handler.should_passthrough("fwd this"));
        assert!(!handler.should_passthrough("keep this"));
        assert_eq!(handler.upstream_url(), "ws://upstream");

        let (mut ctx, mut rx) = test_context();
        handler
            .on_message(&mut ctx, WsMessage::Text("fwd this".to_string()))
            .await
            .unwrap();
        handler
            .on_message(&mut ctx, WsMessage::Text("keep this".to_string()))
            .await
            .unwrap();
        handler
            .on_message(&mut ctx, WsMessage::Binary(vec![1, 2, 3]))
            .await
            .unwrap();

        let reply = WsMessage::from(rx.try_recv().unwrap());
        assert!(matches!(reply, WsMessage::Text(t) if t == "PASSTHROUGH(ws://upstream): fwd this"));
        // Non-matching text and non-text frames produce no reply.
        assert!(rx.try_recv().is_err());
    }
}