use async_trait::async_trait;
use axum::extract::ws::Message;
use futures_util::{SinkExt, StreamExt};
use regex::Regex;
use serde_json::Value;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::{broadcast, Mutex, RwLock};
/// Result alias used by every fallible operation in this module.
pub type HandlerResult<T> = Result<T, HandlerError>;
/// Errors produced by WebSocket handlers, routing, room management and passthrough.
#[derive(Debug, thiserror::Error)]
pub enum HandlerError {
    /// Sending a frame failed (e.g. the outbound channel or an upstream socket failed).
    #[error("Failed to send message: {0}")]
    SendError(String),
    /// JSON (de)serialization failed; `#[from]` lets `?` convert `serde_json::Error`.
    #[error("Failed to parse JSON: {0}")]
    JsonError(#[from] serde_json::Error),
    /// A message pattern (e.g. a regex) could not be compiled.
    #[error("Pattern matching error: {0}")]
    PatternError(String),
    /// A room operation failed (e.g. a room broadcast send returned an error).
    #[error("Room operation failed: {0}")]
    RoomError(String),
    /// Opening a connection (e.g. to an upstream server) failed.
    #[error("Connection error: {0}")]
    ConnectionError(String),
    /// Catch-all for handler-specific failures.
    #[error("Handler error: {0}")]
    Generic(String),
}
/// Transport-independent WebSocket frame handed to [`WsHandler::on_message`].
///
/// `Close` carries no code or reason: converting from an axum `Message::Close`
/// discards the close frame, and converting back yields `Message::Close(None)`.
#[derive(Debug, Clone)]
pub enum WsMessage {
    /// Text frame.
    Text(String),
    /// Binary frame.
    Binary(Vec<u8>),
    /// Ping control frame and its payload.
    Ping(Vec<u8>),
    /// Pong control frame and its payload.
    Pong(Vec<u8>),
    /// Close frame (code and reason are not preserved).
    Close,
}
/// Converts an axum frame into the handler-facing [`WsMessage`].
///
/// Payloads are copied into owned `String` / `Vec<u8>` buffers; the optional close
/// frame of `Message::Close` is dropped.
impl From<Message> for WsMessage {
    fn from(frame: Message) -> Self {
        match frame {
            Message::Close(_) => Self::Close,
            Message::Text(body) => Self::Text(body.to_string()),
            Message::Binary(bytes) => Self::Binary(bytes.to_vec()),
            Message::Ping(payload) => Self::Ping(payload.to_vec()),
            Message::Pong(payload) => Self::Pong(payload.to_vec()),
        }
    }
}
/// Converts a handler-facing [`WsMessage`] back into an axum frame.
///
/// `WsMessage::Close` becomes a close frame without code or reason.
impl From<WsMessage> for Message {
    fn from(frame: WsMessage) -> Self {
        match frame {
            WsMessage::Close => Self::Close(None),
            WsMessage::Text(body) => Self::Text(body.into()),
            WsMessage::Binary(bytes) => Self::Binary(bytes.into()),
            WsMessage::Ping(payload) => Self::Ping(payload.into()),
            WsMessage::Pong(payload) => Self::Pong(payload.into()),
        }
    }
}
/// Predicate used to select incoming text messages (see [`MessagePattern::matches`]).
#[derive(Debug, Clone)]
pub enum MessagePattern {
    /// Matches when `Regex::is_match` succeeds on the text.
    Regex(Regex),
    /// Matches when the text parses as JSON and this JSONPath query selects at least
    /// one value. The query string is compiled on every match.
    JsonPath(String),
    /// Matches only text exactly equal to this string.
    Exact(String),
    /// Matches every message.
    Any,
}
impl MessagePattern {
pub fn regex(pattern: &str) -> HandlerResult<Self> {
Ok(MessagePattern::Regex(
Regex::new(pattern).map_err(|e| HandlerError::PatternError(e.to_string()))?,
))
}
pub fn jsonpath(query: &str) -> Self {
MessagePattern::JsonPath(query.to_string())
}
pub fn exact(text: &str) -> Self {
MessagePattern::Exact(text.to_string())
}
pub fn any() -> Self {
MessagePattern::Any
}
pub fn matches(&self, text: &str) -> bool {
match self {
MessagePattern::Regex(re) => re.is_match(text),
MessagePattern::JsonPath(query) => {
if let Ok(json) = serde_json::from_str::<Value>(text) {
if let Ok(selector) = jsonpath::Selector::new(query) {
let results: Vec<_> = selector.find(&json).collect();
!results.is_empty()
} else {
false
}
} else {
false
}
}
MessagePattern::Exact(expected) => text == expected,
MessagePattern::Any => true,
}
}
pub fn extract(&self, text: &str, query: &str) -> Option<Value> {
if let Ok(json) = serde_json::from_str::<Value>(text) {
if let Ok(selector) = jsonpath::Selector::new(query) {
let results: Vec<_> = selector.find(&json).collect();
results.first().cloned().cloned()
} else {
None
}
} else {
None
}
}
}
/// Identifier of a single WebSocket connection.
pub type ConnectionId = String;
/// Tracks room membership in both directions, plus one broadcast channel per room.
///
/// Every field is an `Arc`, so clones share the same state.
/// NOTE: methods needing both `rooms` and `connections` should lock `rooms` first,
/// so that all callers use one lock order and cannot deadlock each other.
#[derive(Clone)]
pub struct RoomManager {
    // room name -> ids of connections in that room.
    rooms: Arc<RwLock<HashMap<String, HashSet<ConnectionId>>>>,
    // connection id -> rooms it has joined (reverse index of `rooms`).
    connections: Arc<RwLock<HashMap<ConnectionId, HashSet<String>>>>,
    // room name -> broadcast sender, created lazily by `get_broadcaster`.
    broadcasters: Arc<RwLock<HashMap<String, broadcast::Sender<String>>>>,
}
impl RoomManager {
pub fn new() -> Self {
Self {
rooms: Arc::new(RwLock::new(HashMap::new())),
connections: Arc::new(RwLock::new(HashMap::new())),
broadcasters: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn join(&self, conn_id: &str, room: &str) -> HandlerResult<()> {
let mut rooms = self.rooms.write().await;
let mut connections = self.connections.write().await;
rooms
.entry(room.to_string())
.or_insert_with(HashSet::new)
.insert(conn_id.to_string());
connections
.entry(conn_id.to_string())
.or_insert_with(HashSet::new)
.insert(room.to_string());
Ok(())
}
pub async fn leave(&self, conn_id: &str, room: &str) -> HandlerResult<()> {
let mut rooms = self.rooms.write().await;
let mut connections = self.connections.write().await;
if let Some(room_members) = rooms.get_mut(room) {
room_members.remove(conn_id);
if room_members.is_empty() {
rooms.remove(room);
}
}
if let Some(conn_rooms) = connections.get_mut(conn_id) {
conn_rooms.remove(room);
if conn_rooms.is_empty() {
connections.remove(conn_id);
}
}
Ok(())
}
pub async fn leave_all(&self, conn_id: &str) -> HandlerResult<()> {
let mut connections = self.connections.write().await;
if let Some(conn_rooms) = connections.remove(conn_id) {
let mut rooms = self.rooms.write().await;
for room in conn_rooms {
if let Some(room_members) = rooms.get_mut(&room) {
room_members.remove(conn_id);
if room_members.is_empty() {
rooms.remove(&room);
}
}
}
}
Ok(())
}
pub async fn get_room_members(&self, room: &str) -> Vec<ConnectionId> {
let rooms = self.rooms.read().await;
rooms
.get(room)
.map(|members| members.iter().cloned().collect())
.unwrap_or_default()
}
pub async fn get_connection_rooms(&self, conn_id: &str) -> Vec<String> {
let connections = self.connections.read().await;
connections
.get(conn_id)
.map(|rooms| rooms.iter().cloned().collect())
.unwrap_or_default()
}
pub async fn get_broadcaster(&self, room: &str) -> broadcast::Sender<String> {
let mut broadcasters = self.broadcasters.write().await;
broadcasters
.entry(room.to_string())
.or_insert_with(|| {
let (tx, _) = broadcast::channel(1024);
tx
})
.clone()
}
}
impl Default for RoomManager {
fn default() -> Self {
Self::new()
}
}
/// Per-connection state passed to [`WsHandler`] callbacks.
pub struct WsContext {
    /// Identifier of this connection; used as the member id for room operations.
    pub connection_id: ConnectionId,
    /// Path associated with this connection (presumably the upgrade request path — confirm at the call site).
    pub path: String,
    // Shared room registry used by the room helpers.
    room_manager: RoomManager,
    // Sender half of the outbound frame queue; the `send_*` helpers push here.
    // NOTE(review): the receiver is presumably drained by a writer task elsewhere.
    message_tx: tokio::sync::mpsc::UnboundedSender<Message>,
    // Arbitrary per-connection key/value data (`set_metadata` / `get_metadata`).
    metadata: Arc<RwLock<HashMap<String, Value>>>,
}
impl WsContext {
pub fn new(
connection_id: ConnectionId,
path: String,
room_manager: RoomManager,
message_tx: tokio::sync::mpsc::UnboundedSender<Message>,
) -> Self {
Self {
connection_id,
path,
room_manager,
message_tx,
metadata: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn send_text(&self, text: &str) -> HandlerResult<()> {
self.message_tx
.send(Message::Text(text.to_string().into()))
.map_err(|e| HandlerError::SendError(e.to_string()))
}
pub async fn send_binary(&self, data: Vec<u8>) -> HandlerResult<()> {
self.message_tx
.send(Message::Binary(data.into()))
.map_err(|e| HandlerError::SendError(e.to_string()))
}
pub async fn send_json(&self, value: &Value) -> HandlerResult<()> {
let text = serde_json::to_string(value)?;
self.send_text(&text).await
}
pub async fn join_room(&self, room: &str) -> HandlerResult<()> {
self.room_manager.join(&self.connection_id, room).await
}
pub async fn leave_room(&self, room: &str) -> HandlerResult<()> {
self.room_manager.leave(&self.connection_id, room).await
}
pub async fn broadcast_to_room(&self, room: &str, text: &str) -> HandlerResult<()> {
let broadcaster = self.room_manager.get_broadcaster(room).await;
broadcaster
.send(text.to_string())
.map_err(|e| HandlerError::RoomError(e.to_string()))?;
Ok(())
}
pub async fn get_rooms(&self) -> Vec<String> {
self.room_manager.get_connection_rooms(&self.connection_id).await
}
pub async fn set_metadata(&self, key: &str, value: Value) {
let mut metadata = self.metadata.write().await;
metadata.insert(key.to_string(), value);
}
pub async fn get_metadata(&self, key: &str) -> Option<Value> {
let metadata = self.metadata.read().await;
metadata.get(key).cloned()
}
}
/// Callback interface for a WebSocket endpoint.
///
/// Implementations must be `Send + Sync`; [`HandlerRegistry`] stores them as
/// `Arc<dyn WsHandler>`. Only `on_message` has no default implementation.
#[async_trait]
pub trait WsHandler: Send + Sync {
    /// Hook for connection setup (presumably called by the server once per new
    /// connection). The default does nothing and returns `Ok(())`.
    async fn on_connect(&self, _ctx: &mut WsContext) -> HandlerResult<()> {
        Ok(())
    }
    /// Handles one frame `msg` received from the client.
    async fn on_message(&self, ctx: &mut WsContext, msg: WsMessage) -> HandlerResult<()>;
    /// Hook for connection teardown (presumably called once when the connection
    /// closes). The default does nothing and returns `Ok(())`.
    async fn on_disconnect(&self, _ctx: &mut WsContext) -> HandlerResult<()> {
        Ok(())
    }
    /// Whether this handler serves connections on `path`; consulted by
    /// `HandlerRegistry::get_handlers` / `has_handler_for`. The default accepts
    /// every path.
    fn handles_path(&self, _path: &str) -> bool {
        true
    }
}
/// Synchronous route callback: receives the incoming text and optionally returns a reply.
type MessageHandler = Box<dyn Fn(String) -> Option<String> + Send + Sync>;
/// Ordered list of `(pattern, handler)` routes for text messages.
pub struct MessageRouter {
    // Routes in registration order; `route` tries them first to last.
    routes: Vec<(MessagePattern, MessageHandler)>,
}
impl MessageRouter {
pub fn new() -> Self {
Self { routes: Vec::new() }
}
pub fn on<F>(&mut self, pattern: MessagePattern, handler: F) -> &mut Self
where
F: Fn(String) -> Option<String> + Send + Sync + 'static,
{
self.routes.push((pattern, Box::new(handler)));
self
}
pub fn route(&self, text: &str) -> Option<String> {
for (pattern, handler) in &self.routes {
if pattern.matches(text) {
if let Some(response) = handler(text.to_string()) {
return Some(response);
}
}
}
None
}
}
impl Default for MessageRouter {
fn default() -> Self {
Self::new()
}
}
/// Ordered collection of [`WsHandler`]s, filtered per connection path.
pub struct HandlerRegistry {
    // Registered handlers in registration order; shared via `Arc`.
    handlers: Vec<Arc<dyn WsHandler>>,
    // Hot-reload flag; this type only stores and reports it.
    hot_reload_enabled: bool,
}
impl HandlerRegistry {
pub fn new() -> Self {
Self {
handlers: Vec::new(),
hot_reload_enabled: std::env::var("MOCKFORGE_WS_HOTRELOAD")
.map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
.unwrap_or(false),
}
}
pub fn with_hot_reload() -> Self {
Self {
handlers: Vec::new(),
hot_reload_enabled: true,
}
}
pub fn is_hot_reload_enabled(&self) -> bool {
self.hot_reload_enabled
}
pub fn register<H: WsHandler + 'static>(&mut self, handler: H) -> &mut Self {
self.handlers.push(Arc::new(handler));
self
}
pub fn get_handlers(&self, path: &str) -> Vec<Arc<dyn WsHandler>> {
self.handlers.iter().filter(|h| h.handles_path(path)).cloned().collect()
}
pub fn has_handler_for(&self, path: &str) -> bool {
self.handlers.iter().any(|h| h.handles_path(path))
}
pub fn clear(&mut self) {
self.handlers.clear();
}
pub fn len(&self) -> usize {
self.handlers.len()
}
pub fn is_empty(&self) -> bool {
self.handlers.is_empty()
}
}
impl Default for HandlerRegistry {
fn default() -> Self {
Self::new()
}
}
/// Configuration for [`PassthroughHandler`]: which text messages to forward, and where.
#[derive(Clone)]
pub struct PassthroughConfig {
    /// Text messages matching this pattern are forwarded upstream.
    pub pattern: MessagePattern,
    /// WebSocket URL of the upstream server (e.g. `ws://upstream:8080`).
    pub upstream_url: String,
}
impl PassthroughConfig {
pub fn new(pattern: MessagePattern, upstream_url: String) -> Self {
Self {
pattern,
upstream_url,
}
}
pub fn regex(regex: &str, upstream_url: String) -> HandlerResult<Self> {
Ok(Self {
pattern: MessagePattern::regex(regex)?,
upstream_url,
})
}
}
pub struct PassthroughHandler {
config: PassthroughConfig,
upstream_tx: Mutex<Option<UpstreamSender>>,
}
type UpstreamSender = futures_util::stream::SplitSink<
tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
tokio_tungstenite::tungstenite::Message,
>;
impl PassthroughHandler {
pub fn new(config: PassthroughConfig) -> Self {
Self {
config,
upstream_tx: Mutex::new(None),
}
}
pub fn should_passthrough(&self, text: &str) -> bool {
self.config.pattern.matches(text)
}
pub fn upstream_url(&self) -> &str {
&self.config.upstream_url
}
async fn ensure_connected(
&self,
client_tx: &tokio::sync::mpsc::UnboundedSender<Message>,
) -> HandlerResult<()> {
let mut guard = self.upstream_tx.lock().await;
if guard.is_some() {
return Ok(());
}
let url = &self.config.upstream_url;
tracing::info!(upstream = %url, "Connecting to upstream WebSocket server");
let (ws_stream, _response) = tokio_tungstenite::connect_async(url)
.await
.map_err(|e| HandlerError::ConnectionError(format!("Upstream connect failed: {e}")))?;
let (write, mut read) = ws_stream.split();
*guard = Some(write);
let client_tx = client_tx.clone();
tokio::spawn(async move {
while let Some(Ok(msg)) = read.next().await {
let axum_msg = match msg {
tokio_tungstenite::tungstenite::Message::Text(t) => {
Message::Text(t.to_string().into())
}
tokio_tungstenite::tungstenite::Message::Binary(b) => {
Message::Binary(b.to_vec().into())
}
tokio_tungstenite::tungstenite::Message::Ping(p) => {
Message::Ping(p.to_vec().into())
}
tokio_tungstenite::tungstenite::Message::Pong(p) => {
Message::Pong(p.to_vec().into())
}
tokio_tungstenite::tungstenite::Message::Close(_) => {
break;
}
tokio_tungstenite::tungstenite::Message::Frame(_) => continue,
};
if client_tx.send(axum_msg).is_err() {
break;
}
}
tracing::debug!("Upstream reader task finished");
});
Ok(())
}
}
#[async_trait]
impl WsHandler for PassthroughHandler {
async fn on_connect(&self, ctx: &mut WsContext) -> HandlerResult<()> {
self.ensure_connected(&ctx.message_tx).await
}
async fn on_message(&self, ctx: &mut WsContext, msg: WsMessage) -> HandlerResult<()> {
match &msg {
WsMessage::Text(text) if self.should_passthrough(text) => {
self.ensure_connected(&ctx.message_tx).await?;
let mut guard = self.upstream_tx.lock().await;
if let Some(ref mut writer) = *guard {
writer
.send(tokio_tungstenite::tungstenite::Message::Text(text.clone().into()))
.await
.map_err(|e| {
HandlerError::SendError(format!("Upstream send failed: {e}"))
})?;
}
}
WsMessage::Binary(data) => {
self.ensure_connected(&ctx.message_tx).await?;
let mut guard = self.upstream_tx.lock().await;
if let Some(ref mut writer) = *guard {
writer
.send(tokio_tungstenite::tungstenite::Message::Binary(data.clone().into()))
.await
.map_err(|e| {
HandlerError::SendError(format!("Upstream send failed: {e}"))
})?;
}
}
_ => {}
}
Ok(())
}
async fn on_disconnect(&self, _ctx: &mut WsContext) -> HandlerResult<()> {
let mut guard = self.upstream_tx.lock().await;
if let Some(mut writer) = guard.take() {
let _ = writer.send(tokio_tungstenite::tungstenite::Message::Close(None)).await;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    // Unit tests for this module. None of them opens a network connection: the
    // passthrough tests only build handlers and inspect configuration.
    use super::*;

    // ---- WsMessage <-> axum Message conversions ----

    #[test]
    fn test_ws_message_text_from_axum() {
        let axum_msg = Message::Text("hello".to_string().into());
        let ws_msg: WsMessage = axum_msg.into();
        match ws_msg {
            WsMessage::Text(text) => assert_eq!(text, "hello"),
            _ => panic!("Expected Text message"),
        }
    }
    #[test]
    fn test_ws_message_binary_from_axum() {
        let data = vec![1, 2, 3, 4];
        let axum_msg = Message::Binary(data.clone().into());
        let ws_msg: WsMessage = axum_msg.into();
        match ws_msg {
            WsMessage::Binary(bytes) => assert_eq!(bytes, data),
            _ => panic!("Expected Binary message"),
        }
    }
    #[test]
    fn test_ws_message_ping_from_axum() {
        let data = vec![1, 2];
        let axum_msg = Message::Ping(data.clone().into());
        let ws_msg: WsMessage = axum_msg.into();
        match ws_msg {
            WsMessage::Ping(bytes) => assert_eq!(bytes, data),
            _ => panic!("Expected Ping message"),
        }
    }
    #[test]
    fn test_ws_message_pong_from_axum() {
        let data = vec![3, 4];
        let axum_msg = Message::Pong(data.clone().into());
        let ws_msg: WsMessage = axum_msg.into();
        match ws_msg {
            WsMessage::Pong(bytes) => assert_eq!(bytes, data),
            _ => panic!("Expected Pong message"),
        }
    }
    #[test]
    fn test_ws_message_close_from_axum() {
        let axum_msg = Message::Close(None);
        let ws_msg: WsMessage = axum_msg.into();
        assert!(matches!(ws_msg, WsMessage::Close));
    }
    #[test]
    fn test_ws_message_text_to_axum() {
        let ws_msg = WsMessage::Text("hello".to_string());
        let axum_msg: Message = ws_msg.into();
        assert!(matches!(axum_msg, Message::Text(_)));
    }
    #[test]
    fn test_ws_message_binary_to_axum() {
        let ws_msg = WsMessage::Binary(vec![1, 2, 3]);
        let axum_msg: Message = ws_msg.into();
        assert!(matches!(axum_msg, Message::Binary(_)));
    }
    #[test]
    fn test_ws_message_close_to_axum() {
        let ws_msg = WsMessage::Close;
        let axum_msg: Message = ws_msg.into();
        assert!(matches!(axum_msg, Message::Close(_)));
    }

    // ---- MessagePattern::matches ----

    #[test]
    fn test_message_pattern_regex() {
        let pattern = MessagePattern::regex(r"^hello").unwrap();
        assert!(pattern.matches("hello world"));
        assert!(!pattern.matches("goodbye world"));
    }
    #[test]
    fn test_message_pattern_regex_invalid() {
        let result = MessagePattern::regex(r"[invalid");
        assert!(result.is_err());
    }
    #[test]
    fn test_message_pattern_exact() {
        let pattern = MessagePattern::exact("hello");
        assert!(pattern.matches("hello"));
        assert!(!pattern.matches("hello world"));
    }
    #[test]
    fn test_message_pattern_jsonpath() {
        let pattern = MessagePattern::jsonpath("$.type");
        assert!(pattern.matches(r#"{"type": "message"}"#));
        assert!(!pattern.matches(r#"{"name": "test"}"#));
    }
    #[test]
    fn test_message_pattern_jsonpath_nested() {
        let pattern = MessagePattern::jsonpath("$.user.name");
        assert!(pattern.matches(r#"{"user": {"name": "John"}}"#));
        assert!(!pattern.matches(r#"{"user": {"email": "john@example.com"}}"#));
    }
    #[test]
    fn test_message_pattern_jsonpath_invalid_json() {
        let pattern = MessagePattern::jsonpath("$.type");
        assert!(!pattern.matches("not json"));
    }
    #[test]
    fn test_message_pattern_any() {
        let pattern = MessagePattern::any();
        assert!(pattern.matches("anything"));
        assert!(pattern.matches(""));
        assert!(pattern.matches(r#"{"json": true}"#));
    }

    // ---- MessagePattern::extract (the query argument is used, not the pattern) ----

    #[test]
    fn test_message_pattern_extract() {
        let pattern = MessagePattern::jsonpath("$.type");
        let result = pattern.extract(r#"{"type": "greeting", "data": "hello"}"#, "$.type");
        assert_eq!(result, Some(serde_json::json!("greeting")));
    }
    #[test]
    fn test_message_pattern_extract_nested() {
        let pattern = MessagePattern::any();
        let result = pattern.extract(r#"{"user": {"id": 123}}"#, "$.user.id");
        assert_eq!(result, Some(serde_json::json!(123)));
    }
    #[test]
    fn test_message_pattern_extract_not_found() {
        let pattern = MessagePattern::any();
        let result = pattern.extract(r#"{"type": "message"}"#, "$.nonexistent");
        assert!(result.is_none());
    }
    #[test]
    fn test_message_pattern_extract_invalid_json() {
        let pattern = MessagePattern::any();
        let result = pattern.extract("not json", "$.type");
        assert!(result.is_none());
    }

    // ---- RoomManager ----

    #[tokio::test]
    async fn test_room_manager() {
        let manager = RoomManager::new();
        manager.join("conn1", "room1").await.unwrap();
        manager.join("conn1", "room2").await.unwrap();
        manager.join("conn2", "room1").await.unwrap();
        let room1_members = manager.get_room_members("room1").await;
        assert_eq!(room1_members.len(), 2);
        assert!(room1_members.contains(&"conn1".to_string()));
        assert!(room1_members.contains(&"conn2".to_string()));
        let conn1_rooms = manager.get_connection_rooms("conn1").await;
        assert_eq!(conn1_rooms.len(), 2);
        assert!(conn1_rooms.contains(&"room1".to_string()));
        assert!(conn1_rooms.contains(&"room2".to_string()));
        manager.leave("conn1", "room1").await.unwrap();
        let room1_members = manager.get_room_members("room1").await;
        assert_eq!(room1_members.len(), 1);
        assert!(room1_members.contains(&"conn2".to_string()));
        manager.leave_all("conn1").await.unwrap();
        let conn1_rooms = manager.get_connection_rooms("conn1").await;
        assert_eq!(conn1_rooms.len(), 0);
    }
    #[tokio::test]
    async fn test_room_manager_default() {
        let manager = RoomManager::default();
        manager.join("conn1", "room1").await.unwrap();
        let members = manager.get_room_members("room1").await;
        assert_eq!(members.len(), 1);
    }
    #[tokio::test]
    async fn test_room_manager_empty_room() {
        let manager = RoomManager::new();
        let members = manager.get_room_members("nonexistent").await;
        assert!(members.is_empty());
    }
    #[tokio::test]
    async fn test_room_manager_empty_connection() {
        let manager = RoomManager::new();
        let rooms = manager.get_connection_rooms("nonexistent").await;
        assert!(rooms.is_empty());
    }
    #[tokio::test]
    async fn test_room_manager_leave_nonexistent() {
        let manager = RoomManager::new();
        let result = manager.leave("conn1", "room1").await;
        assert!(result.is_ok());
    }
    #[tokio::test]
    async fn test_room_manager_broadcaster() {
        let manager = RoomManager::new();
        manager.join("conn1", "room1").await.unwrap();
        let broadcaster = manager.get_broadcaster("room1").await;
        // Subscribe before sending so the message has a live receiver.
        let mut receiver = broadcaster.subscribe();
        broadcaster.send("hello".to_string()).unwrap();
        let msg = receiver.recv().await.unwrap();
        assert_eq!(msg, "hello");
    }
    #[tokio::test]
    async fn test_room_manager_room_cleanup_on_last_leave() {
        let manager = RoomManager::new();
        manager.join("conn1", "room1").await.unwrap();
        manager.leave("conn1", "room1").await.unwrap();
        let members = manager.get_room_members("room1").await;
        assert!(members.is_empty());
    }

    // ---- MessageRouter ----

    #[test]
    fn test_message_router() {
        let mut router = MessageRouter::new();
        router
            .on(MessagePattern::exact("ping"), |_| Some("pong".to_string()))
            .on(MessagePattern::regex(r"^hello").unwrap(), |_| Some("hi there!".to_string()));
        assert_eq!(router.route("ping"), Some("pong".to_string()));
        assert_eq!(router.route("hello world"), Some("hi there!".to_string()));
        assert_eq!(router.route("goodbye"), None);
    }
    #[test]
    fn test_message_router_default() {
        let router = MessageRouter::default();
        assert_eq!(router.route("anything"), None);
    }
    #[test]
    fn test_message_router_first_match_wins() {
        let mut router = MessageRouter::new();
        router
            .on(MessagePattern::any(), |_| Some("first".to_string()))
            .on(MessagePattern::any(), |_| Some("second".to_string()));
        assert_eq!(router.route("test"), Some("first".to_string()));
    }
    #[test]
    fn test_message_router_handler_returns_none() {
        // A matching route whose handler returns None falls through to the next route.
        let mut router = MessageRouter::new();
        router
            .on(MessagePattern::exact("skip"), |_| None)
            .on(MessagePattern::any(), |_| Some("fallback".to_string()));
        assert_eq!(router.route("skip"), Some("fallback".to_string()));
    }

    // ---- Handler fixtures ----

    /// Handler that accepts every path (default `handles_path`) and ignores messages.
    struct TestHandler;
    #[async_trait]
    impl WsHandler for TestHandler {
        async fn on_message(&self, _ctx: &mut WsContext, _msg: WsMessage) -> HandlerResult<()> {
            Ok(())
        }
    }
    /// Handler that only accepts the exact path it was built with.
    struct PathSpecificHandler {
        path: String,
    }
    #[async_trait]
    impl WsHandler for PathSpecificHandler {
        async fn on_message(&self, _ctx: &mut WsContext, _msg: WsMessage) -> HandlerResult<()> {
            Ok(())
        }
        fn handles_path(&self, path: &str) -> bool {
            path == self.path
        }
    }

    // ---- HandlerRegistry ----

    #[test]
    fn test_handler_registry_new() {
        let registry = HandlerRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }
    #[test]
    fn test_handler_registry_default() {
        let registry = HandlerRegistry::default();
        assert!(registry.is_empty());
    }
    #[test]
    fn test_handler_registry_register() {
        let mut registry = HandlerRegistry::new();
        registry.register(TestHandler);
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
    }
    #[test]
    fn test_handler_registry_get_handlers() {
        let mut registry = HandlerRegistry::new();
        registry.register(TestHandler);
        let handlers = registry.get_handlers("/any/path");
        assert_eq!(handlers.len(), 1);
    }
    #[test]
    fn test_handler_registry_path_filtering() {
        let mut registry = HandlerRegistry::new();
        registry.register(PathSpecificHandler {
            path: "/ws/chat".to_string(),
        });
        registry.register(PathSpecificHandler {
            path: "/ws/events".to_string(),
        });
        let chat_handlers = registry.get_handlers("/ws/chat");
        assert_eq!(chat_handlers.len(), 1);
        let events_handlers = registry.get_handlers("/ws/events");
        assert_eq!(events_handlers.len(), 1);
        let other_handlers = registry.get_handlers("/ws/other");
        assert!(other_handlers.is_empty());
    }
    #[test]
    fn test_handler_registry_has_handler_for() {
        let mut registry = HandlerRegistry::new();
        registry.register(PathSpecificHandler {
            path: "/ws/chat".to_string(),
        });
        assert!(registry.has_handler_for("/ws/chat"));
        assert!(!registry.has_handler_for("/ws/other"));
    }
    #[test]
    fn test_handler_registry_clear() {
        let mut registry = HandlerRegistry::new();
        registry.register(TestHandler);
        registry.register(TestHandler);
        assert_eq!(registry.len(), 2);
        registry.clear();
        assert!(registry.is_empty());
    }
    #[test]
    fn test_handler_registry_with_hot_reload() {
        let registry = HandlerRegistry::with_hot_reload();
        assert!(registry.is_hot_reload_enabled());
    }

    // ---- PassthroughConfig / PassthroughHandler (no upstream connection) ----

    #[test]
    fn test_passthrough_config_new() {
        let config =
            PassthroughConfig::new(MessagePattern::any(), "ws://upstream:8080".to_string());
        assert_eq!(config.upstream_url, "ws://upstream:8080");
    }
    #[test]
    fn test_passthrough_config_regex() {
        let config =
            PassthroughConfig::regex(r"^forward", "ws://upstream:8080".to_string()).unwrap();
        assert!(config.pattern.matches("forward this"));
        assert!(!config.pattern.matches("don't forward"));
    }
    #[test]
    fn test_passthrough_config_regex_invalid() {
        let result = PassthroughConfig::regex(r"[invalid", "ws://upstream:8080".to_string());
        assert!(result.is_err());
    }
    #[test]
    fn test_passthrough_handler_should_passthrough() {
        let config =
            PassthroughConfig::regex(r"^proxy:", "ws://upstream:8080".to_string()).unwrap();
        let handler = PassthroughHandler::new(config);
        assert!(handler.should_passthrough("proxy:hello"));
        assert!(!handler.should_passthrough("hello"));
    }
    #[test]
    fn test_passthrough_handler_upstream_url() {
        let config =
            PassthroughConfig::new(MessagePattern::any(), "ws://upstream:8080".to_string());
        let handler = PassthroughHandler::new(config);
        assert_eq!(handler.upstream_url(), "ws://upstream:8080");
    }

    // ---- HandlerError Display messages ----

    #[test]
    fn test_handler_error_send_error() {
        let err = HandlerError::SendError("connection closed".to_string());
        assert!(err.to_string().contains("send message"));
        assert!(err.to_string().contains("connection closed"));
    }
    #[test]
    fn test_handler_error_json_error() {
        let json_err = serde_json::from_str::<serde_json::Value>("invalid").unwrap_err();
        let err = HandlerError::JsonError(json_err);
        assert!(err.to_string().contains("JSON"));
    }
    #[test]
    fn test_handler_error_pattern_error() {
        let err = HandlerError::PatternError("invalid regex".to_string());
        assert!(err.to_string().contains("Pattern"));
    }
    #[test]
    fn test_handler_error_room_error() {
        let err = HandlerError::RoomError("room full".to_string());
        assert!(err.to_string().contains("Room"));
    }
    #[test]
    fn test_handler_error_connection_error() {
        let err = HandlerError::ConnectionError("timeout".to_string());
        assert!(err.to_string().contains("Connection"));
    }
    #[test]
    fn test_handler_error_generic() {
        let err = HandlerError::Generic("something went wrong".to_string());
        assert!(err.to_string().contains("something went wrong"));
    }

    // ---- WsContext (outbound frames are read back from the mpsc receiver) ----

    #[tokio::test]
    async fn test_ws_context_metadata() {
        let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
        let ctx = WsContext::new("conn123".to_string(), "/ws".to_string(), RoomManager::new(), tx);
        ctx.set_metadata("user", serde_json::json!({"id": 1})).await;
        let value = ctx.get_metadata("user").await;
        assert_eq!(value, Some(serde_json::json!({"id": 1})));
        let missing = ctx.get_metadata("nonexistent").await;
        assert!(missing.is_none());
    }
    #[tokio::test]
    async fn test_ws_context_send_text() {
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let ctx = WsContext::new("conn123".to_string(), "/ws".to_string(), RoomManager::new(), tx);
        ctx.send_text("hello").await.unwrap();
        let msg = rx.recv().await.unwrap();
        assert!(matches!(msg, Message::Text(_)));
    }
    #[tokio::test]
    async fn test_ws_context_send_binary() {
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let ctx = WsContext::new("conn123".to_string(), "/ws".to_string(), RoomManager::new(), tx);
        ctx.send_binary(vec![1, 2, 3]).await.unwrap();
        let msg = rx.recv().await.unwrap();
        assert!(matches!(msg, Message::Binary(_)));
    }
    #[tokio::test]
    async fn test_ws_context_send_json() {
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let ctx = WsContext::new("conn123".to_string(), "/ws".to_string(), RoomManager::new(), tx);
        ctx.send_json(&serde_json::json!({"type": "test"})).await.unwrap();
        let msg = rx.recv().await.unwrap();
        assert!(matches!(msg, Message::Text(_)));
    }
    #[tokio::test]
    async fn test_ws_context_rooms() {
        let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
        let ctx = WsContext::new("conn123".to_string(), "/ws".to_string(), RoomManager::new(), tx);
        ctx.join_room("chat").await.unwrap();
        ctx.join_room("notifications").await.unwrap();
        let rooms = ctx.get_rooms().await;
        assert_eq!(rooms.len(), 2);
        ctx.leave_room("chat").await.unwrap();
        let rooms = ctx.get_rooms().await;
        assert_eq!(rooms.len(), 1);
    }
}