pub mod ai_event_generator;
pub mod handlers;
pub mod protocol_server;
pub mod ws_tracing;
use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
use axum::extract::{Path, State};
use axum::{response::IntoResponse, routing::get, Router};
use futures::sink::SinkExt;
use futures::stream::StreamExt;
use mockforge_core::WsProxyHandler;
#[cfg(feature = "data-faker")]
use mockforge_data::provider::register_core_faker_provider;
use mockforge_foundation::latency::{LatencyInjector, LatencyProfile};
use mockforge_observability::get_global_registry;
use serde_json::Value;
use tokio::fs;
use tokio::time::{sleep, Duration};
use tracing::*;
pub use ai_event_generator::{AiEventGenerator, WebSocketAiConfig};
pub use ws_tracing::{
create_ws_connection_span, create_ws_message_span, record_ws_connection_success,
record_ws_error, record_ws_message_success,
};
pub use handlers::{
HandlerError, HandlerRegistry, HandlerResult, MessagePattern, MessageRouter, PassthroughConfig,
PassthroughHandler, RoomManager, WsContext, WsHandler, WsMessage,
};
/// Builds the basic WebSocket router, serving echo (or replay) sessions on `/ws`.
///
/// With the `data-faker` feature enabled, the core faker provider is registered
/// before the router is assembled.
pub fn router() -> Router {
    #[cfg(feature = "data-faker")]
    register_core_faker_provider();

    let echo_route = get(ws_handler_no_state);
    Router::new().route("/ws", echo_route)
}
/// Builds a `/ws` router that carries `latency_injector` as its axum state.
///
/// With the `data-faker` feature enabled, the core faker provider is registered first.
/// NOTE(review): the `/ws` handler on this router extracts the injector but does not
/// currently apply it — confirm whether latency is meant to be injected.
pub fn router_with_latency(latency_injector: LatencyInjector) -> Router {
    #[cfg(feature = "data-faker")]
    register_core_faker_provider();

    let routes: Router<LatencyInjector> = Router::new().route("/ws", get(ws_handler_with_state));
    routes.with_state(latency_injector)
}
/// Builds a router whose `/ws` and `/ws/{*path}` routes are served through `proxy_handler`.
///
/// Each connection is either forwarded upstream or handled locally, depending on the
/// proxy configuration for the requested path. With the `data-faker` feature enabled,
/// the core faker provider is registered first.
pub fn router_with_proxy(proxy_handler: WsProxyHandler) -> Router {
    #[cfg(feature = "data-faker")]
    register_core_faker_provider();

    let routes: Router<WsProxyHandler> = Router::new()
        .route("/ws", get(ws_handler_with_proxy))
        .route("/ws/{*path}", get(ws_handler_with_proxy_path));
    routes.with_state(proxy_handler)
}
/// Builds a router whose `/ws` and `/ws/{*path}` routes dispatch to handlers from `registry`.
///
/// Paths with no registered handlers fall back to echo mode. With the `data-faker`
/// feature enabled, the core faker provider is registered first.
pub fn router_with_handlers(registry: std::sync::Arc<HandlerRegistry>) -> Router {
    #[cfg(feature = "data-faker")]
    register_core_faker_provider();

    let routes: Router<std::sync::Arc<HandlerRegistry>> = Router::new()
        .route("/ws", get(ws_handler_with_registry))
        .route("/ws/{*path}", get(ws_handler_with_registry_path));
    routes.with_state(registry)
}
/// Starts the WebSocket server on all IPv4 interfaces (`0.0.0.0`) at `port`.
///
/// Shorthand for [`start_with_latency_and_host`] with the host fixed to `0.0.0.0`.
///
/// # Errors
/// Propagates any error from [`start_with_latency_and_host`].
pub async fn start_with_latency(
    port: u16,
    latency: Option<LatencyProfile>,
) -> Result<(), Box<dyn std::error::Error>> {
    const ALL_IPV4_INTERFACES: &str = "0.0.0.0";
    start_with_latency_and_host(port, ALL_IPV4_INTERFACES, latency).await
}
/// Binds a WebSocket server to `host:port` and serves it until the server stops.
///
/// When `latency` is `Some`, the router built by [`router_with_latency`] is used;
/// otherwise the plain [`router`] is used.
///
/// `host` may be a bare IPv4 or IPv6 literal (e.g. `127.0.0.1`, `::1`), or any form that
/// parses as a socket address once formatted as `"{host}:{port}"` (e.g. `[::1]`).
///
/// # Errors
/// Returns an error if the host cannot be parsed as an address, if binding fails (the
/// message includes hints for finding a conflicting process), or if serving fails.
pub async fn start_with_latency_and_host(
    port: u16,
    host: &str,
    latency: Option<LatencyProfile>,
) -> Result<(), Box<dyn std::error::Error>> {
    let router = match latency {
        Some(profile) => router_with_latency(LatencyInjector::new(profile, Default::default())),
        None => router(),
    };

    // A bare IPv6 literal such as "::" or "::1" cannot be joined with ":port" and parsed
    // as a SocketAddr, so try the host as a plain IP first. Otherwise fall back to the
    // "host:port" form, which also covers bracketed and scoped IPv6 addresses.
    let addr: std::net::SocketAddr = match host.parse::<std::net::IpAddr>() {
        Ok(ip) => std::net::SocketAddr::new(ip, port),
        Err(_) => format!("{}:{}", host, port).parse()?,
    };

    let listener = tokio::net::TcpListener::bind(addr).await.map_err(|e| {
        format!(
            "Failed to bind WebSocket server to port {}: {}\n\
             Hint: The port may already be in use. Try using a different port with --ws-port or check if another process is using this port with: lsof -i :{} or netstat -tulpn | grep {}",
            port, e, port, port
        )
    })?;
    // Log only after the bind has succeeded, so the message is never misleading.
    info!("WebSocket server listening on {}", addr);

    axum::serve(listener, router).await?;
    Ok(())
}
/// Axum handler that upgrades `/ws` requests to a WebSocket served by `handle_socket`.
async fn ws_handler_no_state(upgrade: WebSocketUpgrade) -> impl IntoResponse {
    upgrade.on_upgrade(handle_socket)
}
/// Axum handler for `/ws` on the latency-enabled router.
///
/// NOTE(review): the `LatencyInjector` state is extracted but never used, so the
/// connection is served by `handle_socket` exactly as on the stateless router. Confirm
/// whether latency should be applied to replies.
async fn ws_handler_with_state(
    upgrade: WebSocketUpgrade,
    State(_injector): State<LatencyInjector>,
) -> impl IntoResponse {
    upgrade.on_upgrade(handle_socket)
}
/// Axum handler for the bare `/ws` route on the proxy router.
///
/// Upgrades the request and passes the socket, the proxy handler and the path `/ws`
/// to `handle_socket_with_proxy`.
async fn ws_handler_with_proxy(
    upgrade: WebSocketUpgrade,
    State(proxy): State<WsProxyHandler>,
) -> impl IntoResponse {
    let request_path = String::from("/ws");
    upgrade.on_upgrade(move |socket| handle_socket_with_proxy(socket, proxy, request_path))
}
/// Axum handler for `/ws/{*path}` on the proxy router.
///
/// Rebuilds the full request path as `/ws/<path>` and hands the upgraded socket to
/// `handle_socket_with_proxy`.
async fn ws_handler_with_proxy_path(
    Path(path): Path<String>,
    upgrade: WebSocketUpgrade,
    State(proxy): State<WsProxyHandler>,
) -> impl IntoResponse {
    let mut request_path = String::from("/ws/");
    request_path.push_str(&path);
    upgrade.on_upgrade(move |socket| handle_socket_with_proxy(socket, proxy, request_path))
}
/// Axum handler for the bare `/ws` route on the handler-registry router.
///
/// Upgrades the request and runs `handle_socket_with_handlers` for the path `/ws`.
async fn ws_handler_with_registry(
    upgrade: WebSocketUpgrade,
    State(registry): State<std::sync::Arc<HandlerRegistry>>,
) -> impl IntoResponse {
    let request_path = String::from("/ws");
    upgrade.on_upgrade(move |socket| handle_socket_with_handlers(socket, registry, request_path))
}
/// Axum handler for `/ws/{*path}` on the handler-registry router.
///
/// Rebuilds the full request path as `/ws/<path>` and hands the upgraded socket to
/// `handle_socket_with_handlers`.
async fn ws_handler_with_registry_path(
    Path(path): Path<String>,
    upgrade: WebSocketUpgrade,
    State(registry): State<std::sync::Arc<HandlerRegistry>>,
) -> impl IntoResponse {
    let mut request_path = String::from("/ws/");
    request_path.push_str(&path);
    upgrade.on_upgrade(move |socket| handle_socket_with_handlers(socket, registry, request_path))
}
async fn handle_socket(mut socket: WebSocket) {
use std::time::Instant;
let registry = get_global_registry();
let connection_start = Instant::now();
registry.record_ws_connection_established();
debug!("WebSocket connection established, tracking metrics");
let mut status = "normal";
if let Ok(replay_file) = std::env::var("MOCKFORGE_WS_REPLAY_FILE") {
info!("WebSocket replay mode enabled with file: {}", replay_file);
handle_socket_with_replay(socket, &replay_file).await;
} else {
while let Some(msg) = socket.recv().await {
match msg {
Ok(Message::Text(text)) => {
registry.record_ws_message_received();
let response = format!("echo: {}", text);
if socket.send(Message::Text(response.into())).await.is_err() {
status = "send_error";
break;
}
registry.record_ws_message_sent();
}
Ok(Message::Close(_)) => {
status = "client_close";
break;
}
Err(e) => {
error!("WebSocket error: {}", e);
registry.record_ws_error();
status = "error";
break;
}
_ => {}
}
}
}
let duration = connection_start.elapsed().as_secs_f64();
registry.record_ws_connection_closed(duration, status);
debug!("WebSocket connection closed (status: {}, duration: {:.2}s)", status, duration);
}
async fn handle_socket_with_replay(mut socket: WebSocket, replay_file: &str) {
let _registry = get_global_registry();
let content = match fs::read_to_string(replay_file).await {
Ok(content) => content,
Err(e) => {
error!("Failed to read replay file {}: {}", replay_file, e);
return;
}
};
let mut replay_entries = Vec::new();
for line in content.lines() {
if let Ok(entry) = serde_json::from_str::<Value>(line) {
replay_entries.push(entry);
}
}
info!("Loaded {} replay entries", replay_entries.len());
for entry in replay_entries {
if let Some(wait_for) = entry.get("waitFor") {
if let Some(wait_pattern) = wait_for.as_str() {
info!("Waiting for pattern: {}", wait_pattern);
let mut found = false;
while let Some(msg) = socket.recv().await {
if let Ok(Message::Text(text)) = msg {
if text.contains(wait_pattern) || wait_pattern == "^CLIENT_READY$" {
found = true;
break;
}
}
}
if !found {
break;
}
}
}
if let Some(text) = entry.get("text").and_then(|v| v.as_str()) {
let expanded_text = if std::env::var("MOCKFORGE_RESPONSE_TEMPLATE_EXPAND")
.map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
.unwrap_or(false)
{
expand_tokens(text)
} else {
text.to_string()
};
info!("Sending replay message: {}", expanded_text);
if socket.send(Message::Text(expanded_text.into())).await.is_err() {
break;
}
}
if let Some(ts) = entry.get("ts").and_then(|v| v.as_u64()) {
sleep(Duration::from_millis(ts * 10)).await; }
}
}
/// Expands template tokens in a replay message.
///
/// Supported tokens:
/// - `{{uuid}}`: a freshly generated v4 UUID. Every occurrence in `text` gets the same value.
/// - `{{now}}`: the current UTC time in RFC 3339 format.
/// - `{{now+1m}}` / `{{now+1h}}`: the current UTC time plus one minute / one hour.
/// - `{{randInt MIN MAX}}`: a random `i32` in the inclusive range between the two bounds.
///   Each occurrence is drawn independently. Reversed bounds are accepted, and anything
///   other than exactly two integer arguments expands to `0`.
///
/// An unterminated `{{randInt` (no closing `}}`) stops `randInt` expansion and is left as-is.
fn expand_tokens(text: &str) -> String {
    const RAND_INT_OPEN: &str = "{{randInt";
    const TOKEN_CLOSE: &str = "}}";

    let mut result = text.replace("{{uuid}}", &uuid::Uuid::new_v4().to_string());
    result = result.replace("{{now}}", &chrono::Utc::now().to_rfc3339());
    if result.contains("{{now+1m}}") {
        let now_plus_1m = chrono::Utc::now() + chrono::Duration::minutes(1);
        result = result.replace("{{now+1m}}", &now_plus_1m.to_rfc3339());
    }
    if result.contains("{{now+1h}}") {
        let now_plus_1h = chrono::Utc::now() + chrono::Duration::hours(1);
        result = result.replace("{{now+1h}}", &now_plus_1h.to_rfc3339());
    }

    // Expand each `{{randInt ...}}` in place, scanning left to right. Resuming after the
    // inserted value guarantees progress and never re-scans substituted text.
    let mut search_from = 0;
    while let Some(offset) = result[search_from..].find(RAND_INT_OPEN) {
        let token_start = search_from + offset;
        // `{{randInt` contains no `}`, so the first `}}` found lies past the prefix.
        let close_offset = match result[token_start..].find(TOKEN_CLOSE) {
            Some(close_offset) => close_offset,
            None => break,
        };
        let args_end = token_start + close_offset;
        let token_end = args_end + TOKEN_CLOSE.len();

        // The argument text starts with the space after `randInt`, so split on
        // whitespace. Splitting at the first space would yield an empty MIN and
        // always fall back to "0".
        let mut args = result[token_start + RAND_INT_OPEN.len()..args_end].split_whitespace();
        let bounds = match (args.next(), args.next(), args.next()) {
            (Some(a), Some(b), None) => a.parse::<i32>().ok().zip(b.parse::<i32>().ok()),
            _ => None,
        };
        let value = match bounds {
            // `fastrand::i32` panics on an empty range, so order the bounds first.
            Some((a, b)) => fastrand::i32(a.min(b)..=a.max(b)).to_string(),
            None => "0".to_string(),
        };

        result.replace_range(token_start..token_end, &value);
        search_from = token_start + value.len();
    }

    result
}
/// Serves one upgraded WebSocket connection on `path`.
///
/// If `proxy.config.should_proxy(&path)` is true, the connection is proxied upstream.
/// Otherwise it is handled locally by [`handle_socket`].
///
/// Connection metrics are recorded here only for proxied connections. The local path
/// delegates to `handle_socket`, which records its own full lifecycle. The recorded status
/// is `"normal"` when the proxy session ends without error, and `"proxy_error"` when
/// `proxy_connection` fails. A failure is also logged and counted as a WS error.
async fn handle_socket_with_proxy(socket: WebSocket, proxy: WsProxyHandler, path: String) {
    use std::time::Instant;

    if !proxy.config.should_proxy(&path) {
        info!("Handling WebSocket connection locally for path: {}", path);
        // Don't record "established" here. Previously this path recorded it, then a
        // zero-duration close with an empty status label, and `handle_socket` recorded
        // the same connection again.
        handle_socket(socket).await;
        return;
    }

    let registry = get_global_registry();
    let connection_start = Instant::now();
    registry.record_ws_connection_established();

    info!("Proxying WebSocket connection for path: {}", path);
    let status = match proxy.proxy_connection(&path, socket).await {
        Ok(_) => "normal",
        Err(e) => {
            error!("Failed to proxy WebSocket connection: {}", e);
            registry.record_ws_error();
            "proxy_error"
        }
    };

    let duration = connection_start.elapsed().as_secs_f64();
    registry.record_ws_connection_closed(duration, status);
    debug!(
        "Proxied WebSocket connection closed (status: {}, duration: {:.2}s)",
        status, duration
    );
}
/// Serves one upgraded WebSocket connection using the handlers registered for `path`.
///
/// If `registry` has no handlers for `path`, the connection falls back to
/// [`handle_socket`] (echo/replay mode), which records its own connection metrics.
///
/// Otherwise:
/// - Every handler receives `on_connect`.
/// - Every handler then receives `on_message` for each incoming non-close message.
/// - Finally every handler receives `on_disconnect`, once the client sends Close, the
///   stream ends, or a receive error occurs.
///
/// Handler errors are logged and set the status to `"handler_error"`, but do not end the
/// connection. Messages that handlers queue through the context's channel are forwarded
/// to the client by a spawned task. That task is aborted after the handlers have
/// disconnected and the connection has left all rooms.
async fn handle_socket_with_handlers(
    socket: WebSocket,
    registry: std::sync::Arc<HandlerRegistry>,
    path: String,
) {
    use std::time::Instant;

    let handlers = registry.get_handlers(&path);
    if handlers.is_empty() {
        info!("No handlers found for path: {}, falling back to echo mode", path);
        // Don't record metrics here: `handle_socket` records the full connection
        // lifecycle itself. Recording here too would double-count the connection and
        // emit a spurious zero-duration close with an empty status label.
        handle_socket(socket).await;
        return;
    }

    let metrics_registry = get_global_registry();
    let connection_start = Instant::now();
    metrics_registry.record_ws_connection_established();
    let mut status = "normal";
    let connection_id = uuid::Uuid::new_v4().to_string();

    info!(
        "Handling WebSocket connection with {} handler(s) for path: {}",
        handlers.len(),
        path
    );

    let room_manager = RoomManager::new();
    let (mut socket_sender, mut socket_receiver) = socket.split();
    let (message_tx, mut message_rx) = tokio::sync::mpsc::unbounded_channel::<Message>();
    let mut ctx =
        WsContext::new(connection_id.clone(), path.clone(), room_manager.clone(), message_tx);

    for handler in &handlers {
        if let Err(e) = handler.on_connect(&mut ctx).await {
            error!("Handler on_connect error: {}", e);
            status = "handler_error";
        }
    }

    // Forward handler-produced messages to the client. "Sent" is counted here, once per
    // frame actually written, because a handler may emit zero or many messages for each
    // one it receives.
    let send_task = tokio::spawn(async move {
        let send_metrics = get_global_registry();
        while let Some(msg) = message_rx.recv().await {
            if socket_sender.send(msg).await.is_err() {
                break;
            }
            send_metrics.record_ws_message_sent();
        }
    });

    while let Some(msg) = socket_receiver.next().await {
        match msg {
            Ok(axum_msg) => {
                metrics_registry.record_ws_message_received();
                let ws_msg: WsMessage = axum_msg.into();
                if matches!(ws_msg, WsMessage::Close) {
                    status = "client_close";
                    break;
                }
                for handler in &handlers {
                    if let Err(e) = handler.on_message(&mut ctx, ws_msg.clone()).await {
                        error!("Handler on_message error: {}", e);
                        status = "handler_error";
                    }
                }
            }
            Err(e) => {
                error!("WebSocket error: {}", e);
                metrics_registry.record_ws_error();
                status = "error";
                break;
            }
        }
    }

    for handler in &handlers {
        if let Err(e) = handler.on_disconnect(&mut ctx).await {
            error!("Handler on_disconnect error: {}", e);
        }
    }
    let _ = room_manager.leave_all(&connection_id).await;
    send_task.abort();

    let duration = connection_start.elapsed().as_secs_f64();
    metrics_registry.record_ws_connection_closed(duration, status);
    debug!(
        "Handler-based WebSocket connection closed (status: {}, duration: {:.2}s)",
        status, duration
    );
}
// Unit tests for router construction, proxy config and replay-token expansion.
// None of these tests bind a socket or start a server.
#[cfg(test)]
mod tests {
use super::*;
// Smoke test: building the echo router must not panic.
#[test]
fn test_router_creation() {
let _router = router();
}
// Smoke test: building the latency router from a default profile must not panic.
#[test]
fn test_router_with_latency_creation() {
let latency_profile = LatencyProfile::default();
let latency_injector = LatencyInjector::new(latency_profile, Default::default());
let _router = router_with_latency(latency_injector);
}
// Smoke test: building the proxy router must not panic. No upstream connection is made.
#[test]
fn test_router_with_proxy_creation() {
let config = mockforge_core::WsProxyConfig {
upstream_url: "ws://localhost:8080".to_string(),
..Default::default()
};
let proxy_handler = WsProxyHandler::new(config);
let _router = router_with_proxy(proxy_handler);
}
// Smoke test: building the handler-registry router with an empty registry must not panic.
#[test]
fn test_router_with_handlers_creation() {
let registry = std::sync::Arc::new(HandlerRegistry::new());
let _router = router_with_handlers(registry);
}
// Despite the name, this does not call `start_with_latency`. It only checks that
// building the plain router (the `latency == None` branch) does not panic.
#[tokio::test]
async fn test_start_with_latency_config_none() {
let result = std::panic::catch_unwind(|| {
let _router = router();
});
assert!(result.is_ok());
}
// Despite the name, this does not call `start_with_latency`. It only checks that
// building the latency router (the `latency == Some` branch) does not panic.
#[tokio::test]
async fn test_start_with_latency_config_some() {
let latency_profile = LatencyProfile::default();
let latency_injector = LatencyInjector::new(latency_profile, Default::default());
let result = std::panic::catch_unwind(|| {
let _router = router_with_latency(latency_injector);
});
assert!(result.is_ok());
}
// `session-` is 8 bytes, so the remainder must be the 36-character hyphenated UUID.
#[test]
fn test_expand_tokens_uuid() {
let text = "session-{{uuid}}";
let expanded = expand_tokens(text);
assert!(!expanded.contains("{{uuid}}"));
assert!(expanded.starts_with("session-"));
let uuid_part = &expanded[8..];
assert_eq!(uuid_part.len(), 36);
}
// RFC 3339 timestamps contain a `T` date/time separator.
#[test]
fn test_expand_tokens_now() {
let text = "time: {{now}}";
let expanded = expand_tokens(text);
assert!(!expanded.contains("{{now}}"));
assert!(expanded.starts_with("time: "));
assert!(expanded.contains("T"));
}
#[test]
fn test_expand_tokens_now_plus_1m() {
let text = "expires: {{now+1m}}";
let expanded = expand_tokens(text);
assert!(!expanded.contains("{{now+1m}}"));
assert!(expanded.starts_with("expires: "));
}
#[test]
fn test_expand_tokens_now_plus_1h() {
let text = "expires: {{now+1h}}";
let expanded = expand_tokens(text);
assert!(!expanded.contains("{{now+1h}}"));
assert!(expanded.starts_with("expires: "));
}
// Only checks that the token is removed. The generated value is not range-checked.
#[test]
fn test_expand_tokens_randint() {
let text = "value: {{randInt 1 100}}";
let expanded = expand_tokens(text);
assert!(!expanded.contains("{{randInt"), "Token should be expanded");
assert!(expanded.starts_with("value: "));
}
// Several `randInt` tokens in one string must all be expanded (values are not checked).
#[test]
fn test_expand_tokens_randint_multiple() {
let text = "a: {{randInt 1 10}}, b: {{randInt 20 30}}";
let expanded = expand_tokens(text);
assert!(!expanded.contains("{{randInt"));
assert!(expanded.contains("a: "));
assert!(expanded.contains("b: "));
}
#[test]
fn test_expand_tokens_mixed() {
let text = "id: {{uuid}}, time: {{now}}, rand: {{randInt 1 10}}";
let expanded = expand_tokens(text);
assert!(!expanded.contains("{{uuid}}"));
assert!(!expanded.contains("{{now}}"));
assert!(!expanded.contains("{{randInt"));
}
// Text without tokens must come back unchanged.
#[test]
fn test_expand_tokens_no_tokens() {
let text = "plain text without tokens";
let expanded = expand_tokens(text);
assert_eq!(expanded, text);
}
#[test]
fn test_latency_profile_default() {
let profile = LatencyProfile::default();
let injector = LatencyInjector::new(profile, Default::default());
let _router = router_with_latency(injector);
}
// Smoke test with a normal-distribution profile clamped to [50, 200] ms.
#[test]
fn test_latency_profile_with_normal_distribution() {
let profile = LatencyProfile::with_normal_distribution(100, 25.0)
.with_min_ms(50)
.with_max_ms(200);
let injector = LatencyInjector::new(profile, Default::default());
let _router = router_with_latency(injector);
}
// Only checks that `WsProxyConfig::default()` exposes `upstream_url`. No value is asserted.
#[test]
fn test_ws_proxy_config_default() {
let config = mockforge_core::WsProxyConfig::default();
let _url = &config.upstream_url;
}
#[test]
fn test_ws_proxy_config_custom() {
let config = mockforge_core::WsProxyConfig {
upstream_url: "wss://api.example.com/ws".to_string(),
..Default::default()
};
assert_eq!(config.upstream_url, "wss://api.example.com/ws");
}
// Checks that the re-exported tracing, handler-registry and pattern APIs are reachable
// through `super::*`.
#[test]
fn test_reexports_available() {
let _ = create_ws_connection_span("conn-123");
let _registry = HandlerRegistry::new();
let _pattern = MessagePattern::any();
}
}