1use async_trait::async_trait;
73use axum::extract::ws::Message;
74use regex::Regex;
75use serde_json::Value;
76use std::collections::{HashMap, HashSet};
77use std::sync::Arc;
78use tokio::sync::{broadcast, RwLock};
79
/// Convenience result type whose error is [`HandlerError`].
pub type HandlerResult<T> = Result<T, HandlerError>;

/// Errors produced by WebSocket handlers and the helpers they use
/// (message patterns, rooms, and the per-connection context).
#[derive(Debug, thiserror::Error)]
pub enum HandlerError {
    /// Queuing an outgoing frame on the connection's outbound channel failed
    /// (returned by `WsContext::send_text` / `WsContext::send_binary`).
    #[error("Failed to send message: {0}")]
    SendError(String),

    /// JSON (de)serialization failed. Converted automatically from
    /// `serde_json::Error`, so `?` works on serde_json calls.
    #[error("Failed to parse JSON: {0}")]
    JsonError(#[from] serde_json::Error),

    /// A message pattern could not be built, e.g. an invalid regular
    /// expression passed to `MessagePattern::regex`.
    #[error("Pattern matching error: {0}")]
    PatternError(String),

    /// A room operation failed, e.g. `WsContext::broadcast_to_room` could not
    /// send on the room's broadcast channel.
    #[error("Room operation failed: {0}")]
    RoomError(String),

    /// Connection-level failure carrying a free-form description.
    #[error("Connection error: {0}")]
    ConnectionError(String),

    /// General-purpose handler failure carrying a free-form message.
    #[error("Handler error: {0}")]
    Generic(String),
}
110
111#[derive(Debug, Clone)]
113pub enum WsMessage {
114 Text(String),
116 Binary(Vec<u8>),
118 Ping(Vec<u8>),
120 Pong(Vec<u8>),
122 Close,
124}
125
126impl From<Message> for WsMessage {
127 fn from(msg: Message) -> Self {
128 match msg {
129 Message::Text(text) => WsMessage::Text(text.to_string()),
130 Message::Binary(data) => WsMessage::Binary(data.to_vec()),
131 Message::Ping(data) => WsMessage::Ping(data.to_vec()),
132 Message::Pong(data) => WsMessage::Pong(data.to_vec()),
133 Message::Close(_) => WsMessage::Close,
134 }
135 }
136}
137
138impl From<WsMessage> for Message {
139 fn from(msg: WsMessage) -> Self {
140 match msg {
141 WsMessage::Text(text) => Message::Text(text.into()),
142 WsMessage::Binary(data) => Message::Binary(data.into()),
143 WsMessage::Ping(data) => Message::Ping(data.into()),
144 WsMessage::Pong(data) => Message::Pong(data.into()),
145 WsMessage::Close => Message::Close(None),
146 }
147 }
148}
149
150#[derive(Debug, Clone)]
152pub enum MessagePattern {
153 Regex(Regex),
155 JsonPath(String),
157 Exact(String),
159 Any,
161}
162
163impl MessagePattern {
164 pub fn regex(pattern: &str) -> HandlerResult<Self> {
166 Ok(MessagePattern::Regex(
167 Regex::new(pattern).map_err(|e| HandlerError::PatternError(e.to_string()))?,
168 ))
169 }
170
171 pub fn jsonpath(query: &str) -> Self {
173 MessagePattern::JsonPath(query.to_string())
174 }
175
176 pub fn exact(text: &str) -> Self {
178 MessagePattern::Exact(text.to_string())
179 }
180
181 pub fn any() -> Self {
183 MessagePattern::Any
184 }
185
186 pub fn matches(&self, text: &str) -> bool {
188 match self {
189 MessagePattern::Regex(re) => re.is_match(text),
190 MessagePattern::JsonPath(query) => {
191 if let Ok(json) = serde_json::from_str::<Value>(text) {
193 if let Ok(selector) = jsonpath::Selector::new(query) {
195 let results: Vec<_> = selector.find(&json).collect();
196 !results.is_empty()
197 } else {
198 false
199 }
200 } else {
201 false
202 }
203 }
204 MessagePattern::Exact(expected) => text == expected,
205 MessagePattern::Any => true,
206 }
207 }
208
209 pub fn extract(&self, text: &str, query: &str) -> Option<Value> {
211 if let Ok(json) = serde_json::from_str::<Value>(text) {
212 if let Ok(selector) = jsonpath::Selector::new(query) {
213 let results: Vec<_> = selector.find(&json).collect();
214 results.first().cloned().cloned()
215 } else {
216 None
217 }
218 } else {
219 None
220 }
221 }
222}
223
224pub type ConnectionId = String;
226
227#[derive(Clone)]
229pub struct RoomManager {
230 rooms: Arc<RwLock<HashMap<String, HashSet<ConnectionId>>>>,
231 connections: Arc<RwLock<HashMap<ConnectionId, HashSet<String>>>>,
232 broadcasters: Arc<RwLock<HashMap<String, broadcast::Sender<String>>>>,
233}
234
235impl RoomManager {
236 pub fn new() -> Self {
238 Self {
239 rooms: Arc::new(RwLock::new(HashMap::new())),
240 connections: Arc::new(RwLock::new(HashMap::new())),
241 broadcasters: Arc::new(RwLock::new(HashMap::new())),
242 }
243 }
244
245 pub async fn join(&self, conn_id: &str, room: &str) -> HandlerResult<()> {
247 let mut rooms = self.rooms.write().await;
248 let mut connections = self.connections.write().await;
249
250 rooms
251 .entry(room.to_string())
252 .or_insert_with(HashSet::new)
253 .insert(conn_id.to_string());
254
255 connections
256 .entry(conn_id.to_string())
257 .or_insert_with(HashSet::new)
258 .insert(room.to_string());
259
260 Ok(())
261 }
262
263 pub async fn leave(&self, conn_id: &str, room: &str) -> HandlerResult<()> {
265 let mut rooms = self.rooms.write().await;
266 let mut connections = self.connections.write().await;
267
268 if let Some(room_members) = rooms.get_mut(room) {
269 room_members.remove(conn_id);
270 if room_members.is_empty() {
271 rooms.remove(room);
272 }
273 }
274
275 if let Some(conn_rooms) = connections.get_mut(conn_id) {
276 conn_rooms.remove(room);
277 if conn_rooms.is_empty() {
278 connections.remove(conn_id);
279 }
280 }
281
282 Ok(())
283 }
284
285 pub async fn leave_all(&self, conn_id: &str) -> HandlerResult<()> {
287 let mut connections = self.connections.write().await;
288 if let Some(conn_rooms) = connections.remove(conn_id) {
289 let mut rooms = self.rooms.write().await;
290 for room in conn_rooms {
291 if let Some(room_members) = rooms.get_mut(&room) {
292 room_members.remove(conn_id);
293 if room_members.is_empty() {
294 rooms.remove(&room);
295 }
296 }
297 }
298 }
299 Ok(())
300 }
301
302 pub async fn get_room_members(&self, room: &str) -> Vec<ConnectionId> {
304 let rooms = self.rooms.read().await;
305 rooms
306 .get(room)
307 .map(|members| members.iter().cloned().collect())
308 .unwrap_or_default()
309 }
310
311 pub async fn get_connection_rooms(&self, conn_id: &str) -> Vec<String> {
313 let connections = self.connections.read().await;
314 connections
315 .get(conn_id)
316 .map(|rooms| rooms.iter().cloned().collect())
317 .unwrap_or_default()
318 }
319
320 pub async fn get_broadcaster(&self, room: &str) -> broadcast::Sender<String> {
322 let mut broadcasters = self.broadcasters.write().await;
323 broadcasters
324 .entry(room.to_string())
325 .or_insert_with(|| {
326 let (tx, _) = broadcast::channel(1024);
327 tx
328 })
329 .clone()
330 }
331}
332
333impl Default for RoomManager {
334 fn default() -> Self {
335 Self::new()
336 }
337}
338
339pub struct WsContext {
341 pub connection_id: ConnectionId,
343 pub path: String,
345 room_manager: RoomManager,
347 message_tx: tokio::sync::mpsc::UnboundedSender<Message>,
349 metadata: Arc<RwLock<HashMap<String, Value>>>,
351}
352
353impl WsContext {
354 pub fn new(
356 connection_id: ConnectionId,
357 path: String,
358 room_manager: RoomManager,
359 message_tx: tokio::sync::mpsc::UnboundedSender<Message>,
360 ) -> Self {
361 Self {
362 connection_id,
363 path,
364 room_manager,
365 message_tx,
366 metadata: Arc::new(RwLock::new(HashMap::new())),
367 }
368 }
369
370 pub async fn send_text(&self, text: &str) -> HandlerResult<()> {
372 self.message_tx
373 .send(Message::Text(text.to_string().into()))
374 .map_err(|e| HandlerError::SendError(e.to_string()))
375 }
376
377 pub async fn send_binary(&self, data: Vec<u8>) -> HandlerResult<()> {
379 self.message_tx
380 .send(Message::Binary(data.into()))
381 .map_err(|e| HandlerError::SendError(e.to_string()))
382 }
383
384 pub async fn send_json(&self, value: &Value) -> HandlerResult<()> {
386 let text = serde_json::to_string(value)?;
387 self.send_text(&text).await
388 }
389
390 pub async fn join_room(&self, room: &str) -> HandlerResult<()> {
392 self.room_manager.join(&self.connection_id, room).await
393 }
394
395 pub async fn leave_room(&self, room: &str) -> HandlerResult<()> {
397 self.room_manager.leave(&self.connection_id, room).await
398 }
399
400 pub async fn broadcast_to_room(&self, room: &str, text: &str) -> HandlerResult<()> {
402 let broadcaster = self.room_manager.get_broadcaster(room).await;
403 broadcaster
404 .send(text.to_string())
405 .map_err(|e| HandlerError::RoomError(e.to_string()))?;
406 Ok(())
407 }
408
409 pub async fn get_rooms(&self) -> Vec<String> {
411 self.room_manager.get_connection_rooms(&self.connection_id).await
412 }
413
414 pub async fn set_metadata(&self, key: &str, value: Value) {
416 let mut metadata = self.metadata.write().await;
417 metadata.insert(key.to_string(), value);
418 }
419
420 pub async fn get_metadata(&self, key: &str) -> Option<Value> {
422 let metadata = self.metadata.read().await;
423 metadata.get(key).cloned()
424 }
425}
426
/// Callbacks for a WebSocket connection.
///
/// Only `on_message` must be implemented. `on_connect` and `on_disconnect`
/// default to returning `Ok(())`, and `handles_path` defaults to `true`.
#[async_trait]
pub trait WsHandler: Send + Sync {
    /// Connection-open hook. Default: does nothing and returns `Ok(())`.
    /// NOTE(review): the caller that invokes this is not shown here; presumably
    /// it runs once per connection. Confirm at the dispatch site.
    async fn on_connect(&self, _ctx: &mut WsContext) -> HandlerResult<()> {
        Ok(())
    }

    /// Handles one incoming message. `ctx` gives access to the connection's
    /// send, room, and metadata helpers.
    async fn on_message(&self, ctx: &mut WsContext, msg: WsMessage) -> HandlerResult<()>;

    /// Connection-close hook. Default: does nothing and returns `Ok(())`.
    /// NOTE(review): invocation point not shown here. Confirm at the dispatch site.
    async fn on_disconnect(&self, _ctx: &mut WsContext) -> HandlerResult<()> {
        Ok(())
    }

    /// Whether this handler applies to connections on `path`.
    /// `HandlerRegistry::get_handlers` / `has_handler_for` filter on this.
    /// Default: `true` (every path).
    fn handles_path(&self, _path: &str) -> bool {
        true
    }
}
448
449pub struct MessageRouter {
451 routes: Vec<(MessagePattern, Box<dyn Fn(String) -> Option<String> + Send + Sync>)>,
452}
453
454impl MessageRouter {
455 pub fn new() -> Self {
457 Self { routes: Vec::new() }
458 }
459
460 pub fn on<F>(&mut self, pattern: MessagePattern, handler: F) -> &mut Self
462 where
463 F: Fn(String) -> Option<String> + Send + Sync + 'static,
464 {
465 self.routes.push((pattern, Box::new(handler)));
466 self
467 }
468
469 pub fn route(&self, text: &str) -> Option<String> {
471 for (pattern, handler) in &self.routes {
472 if pattern.matches(text) {
473 if let Some(response) = handler(text.to_string()) {
474 return Some(response);
475 }
476 }
477 }
478 None
479 }
480}
481
482impl Default for MessageRouter {
483 fn default() -> Self {
484 Self::new()
485 }
486}
487
488pub struct HandlerRegistry {
490 handlers: Vec<Arc<dyn WsHandler>>,
491 hot_reload_enabled: bool,
492}
493
494impl HandlerRegistry {
495 pub fn new() -> Self {
497 Self {
498 handlers: Vec::new(),
499 hot_reload_enabled: std::env::var("MOCKFORGE_WS_HOTRELOAD")
500 .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
501 .unwrap_or(false),
502 }
503 }
504
505 pub fn with_hot_reload() -> Self {
507 Self {
508 handlers: Vec::new(),
509 hot_reload_enabled: true,
510 }
511 }
512
513 pub fn is_hot_reload_enabled(&self) -> bool {
515 self.hot_reload_enabled
516 }
517
518 pub fn register<H: WsHandler + 'static>(&mut self, handler: H) -> &mut Self {
520 self.handlers.push(Arc::new(handler));
521 self
522 }
523
524 pub fn get_handlers(&self, path: &str) -> Vec<Arc<dyn WsHandler>> {
526 self.handlers.iter().filter(|h| h.handles_path(path)).cloned().collect()
527 }
528
529 pub fn has_handler_for(&self, path: &str) -> bool {
531 self.handlers.iter().any(|h| h.handles_path(path))
532 }
533
534 pub fn clear(&mut self) {
536 self.handlers.clear();
537 }
538
539 pub fn len(&self) -> usize {
541 self.handlers.len()
542 }
543
544 pub fn is_empty(&self) -> bool {
546 self.handlers.is_empty()
547 }
548}
549
550impl Default for HandlerRegistry {
551 fn default() -> Self {
552 Self::new()
553 }
554}
555
556#[derive(Clone)]
558pub struct PassthroughConfig {
559 pub pattern: MessagePattern,
561 pub upstream_url: String,
563}
564
565impl PassthroughConfig {
566 pub fn new(pattern: MessagePattern, upstream_url: String) -> Self {
568 Self {
569 pattern,
570 upstream_url,
571 }
572 }
573
574 pub fn regex(regex: &str, upstream_url: String) -> HandlerResult<Self> {
576 Ok(Self {
577 pattern: MessagePattern::regex(regex)?,
578 upstream_url,
579 })
580 }
581}
582
583pub struct PassthroughHandler {
585 config: PassthroughConfig,
586}
587
588impl PassthroughHandler {
589 pub fn new(config: PassthroughConfig) -> Self {
591 Self { config }
592 }
593
594 pub fn should_passthrough(&self, text: &str) -> bool {
596 self.config.pattern.matches(text)
597 }
598
599 pub fn upstream_url(&self) -> &str {
601 &self.config.upstream_url
602 }
603}
604
605#[async_trait]
606impl WsHandler for PassthroughHandler {
607 async fn on_message(&self, ctx: &mut WsContext, msg: WsMessage) -> HandlerResult<()> {
608 if let WsMessage::Text(text) = &msg {
609 if self.should_passthrough(text) {
610 ctx.send_text(&format!("PASSTHROUGH({}): {}", self.config.upstream_url, text))
613 .await?;
614 return Ok(());
615 }
616 }
617 Ok(())
618 }
619}
620
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_pattern_regex() {
        let pattern = MessagePattern::regex(r"^hello").unwrap();
        assert!(pattern.matches("hello world"));
        assert!(!pattern.matches("goodbye world"));
    }

    #[test]
    fn test_message_pattern_invalid_regex() {
        // An unclosed group must surface as a PatternError, not a panic.
        assert!(matches!(
            MessagePattern::regex("("),
            Err(HandlerError::PatternError(_))
        ));
    }

    #[test]
    fn test_message_pattern_exact() {
        let pattern = MessagePattern::exact("hello");
        assert!(pattern.matches("hello"));
        assert!(!pattern.matches("hello world"));
    }

    #[test]
    fn test_message_pattern_jsonpath() {
        let pattern = MessagePattern::jsonpath("$.type");
        assert!(pattern.matches(r#"{"type": "message"}"#));
        assert!(!pattern.matches(r#"{"name": "test"}"#));
        // Non-JSON input never matches a JSONPath pattern.
        assert!(!pattern.matches("not json"));
    }

    #[test]
    fn test_message_pattern_extract() {
        let pattern = MessagePattern::any();
        assert_eq!(
            pattern.extract(r#"{"type": "chat"}"#, "$.type"),
            Some(Value::String("chat".to_string()))
        );
        assert_eq!(pattern.extract(r#"{"name": "test"}"#, "$.type"), None);
        assert_eq!(pattern.extract("not json", "$.type"), None);
    }

    #[tokio::test]
    async fn test_room_manager() {
        let manager = RoomManager::new();

        manager.join("conn1", "room1").await.unwrap();
        manager.join("conn1", "room2").await.unwrap();
        manager.join("conn2", "room1").await.unwrap();

        let room1_members = manager.get_room_members("room1").await;
        assert_eq!(room1_members.len(), 2);
        assert!(room1_members.contains(&"conn1".to_string()));
        assert!(room1_members.contains(&"conn2".to_string()));

        let conn1_rooms = manager.get_connection_rooms("conn1").await;
        assert_eq!(conn1_rooms.len(), 2);
        assert!(conn1_rooms.contains(&"room1".to_string()));
        assert!(conn1_rooms.contains(&"room2".to_string()));

        manager.leave("conn1", "room1").await.unwrap();
        let room1_members = manager.get_room_members("room1").await;
        assert_eq!(room1_members.len(), 1);
        assert!(room1_members.contains(&"conn2".to_string()));

        manager.leave_all("conn1").await.unwrap();
        let conn1_rooms = manager.get_connection_rooms("conn1").await;
        assert_eq!(conn1_rooms.len(), 0);
        // leave_all must also prune conn1 from the rooms it was still in:
        // room2 had conn1 as its only member, so it should now be empty.
        assert!(manager.get_room_members("room2").await.is_empty());
        // Other connections are unaffected.
        assert_eq!(
            manager.get_room_members("room1").await,
            vec!["conn2".to_string()]
        );
    }

    #[test]
    fn test_message_router() {
        let mut router = MessageRouter::new();

        router
            .on(MessagePattern::exact("ping"), |_| Some("pong".to_string()))
            .on(MessagePattern::regex(r"^hello").unwrap(), |_| Some("hi there!".to_string()));

        assert_eq!(router.route("ping"), Some("pong".to_string()));
        assert_eq!(router.route("hello world"), Some("hi there!".to_string()));
        assert_eq!(router.route("goodbye"), None);
    }

    #[test]
    fn test_message_router_falls_through_on_none() {
        // A matching route that declines (returns None) must not stop later
        // routes from answering.
        let mut router = MessageRouter::new();
        router
            .on(MessagePattern::any(), |_| None)
            .on(MessagePattern::exact("ping"), |_| Some("pong".to_string()));

        assert_eq!(router.route("ping"), Some("pong".to_string()));
        assert_eq!(router.route("other"), None);
    }
}