1pub mod ai_event_generator;
152pub mod handlers;
153pub mod ws_tracing;
154
155use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
156use axum::extract::{Path, State};
157use axum::{response::IntoResponse, routing::get, Router};
158use futures::sink::SinkExt;
159use futures::stream::StreamExt;
160use mockforge_core::{latency::LatencyInjector, LatencyProfile, WsProxyHandler};
161#[cfg(feature = "data-faker")]
162use mockforge_data::provider::register_core_faker_provider;
163use mockforge_observability::get_global_registry;
164use serde_json::Value;
165use tokio::fs;
166use tokio::time::{sleep, Duration};
167use tracing::*;
168
169pub use ai_event_generator::{AiEventGenerator, WebSocketAiConfig};
171
172pub use ws_tracing::{
174 create_ws_connection_span, create_ws_message_span, record_ws_connection_success,
175 record_ws_error, record_ws_message_success,
176};
177
178pub use handlers::{
180 HandlerError, HandlerRegistry, HandlerResult, MessagePattern, MessageRouter, PassthroughConfig,
181 PassthroughHandler, RoomManager, WsContext, WsHandler, WsMessage,
182};
183
184pub fn router() -> Router {
186 #[cfg(feature = "data-faker")]
187 register_core_faker_provider();
188
189 Router::new().route("/ws", get(ws_handler_no_state))
190}
191
192pub fn router_with_latency(latency_injector: LatencyInjector) -> Router {
194 #[cfg(feature = "data-faker")]
195 register_core_faker_provider();
196
197 Router::new()
198 .route("/ws", get(ws_handler_with_state))
199 .with_state(latency_injector)
200}
201
202pub fn router_with_proxy(proxy_handler: WsProxyHandler) -> Router {
204 #[cfg(feature = "data-faker")]
205 register_core_faker_provider();
206
207 Router::new()
208 .route("/ws", get(ws_handler_with_proxy))
209 .route("/ws/{*path}", get(ws_handler_with_proxy_path))
210 .with_state(proxy_handler)
211}
212
213pub fn router_with_handlers(registry: std::sync::Arc<HandlerRegistry>) -> Router {
215 #[cfg(feature = "data-faker")]
216 register_core_faker_provider();
217
218 Router::new()
219 .route("/ws", get(ws_handler_with_registry))
220 .route("/ws/{*path}", get(ws_handler_with_registry_path))
221 .with_state(registry)
222}
223
224pub async fn start_with_latency(
226 port: u16,
227 latency: Option<LatencyProfile>,
228) -> Result<(), Box<dyn std::error::Error>> {
229 let latency_injector = latency.map(|profile| LatencyInjector::new(profile, Default::default()));
230 let router = if let Some(injector) = latency_injector {
231 router_with_latency(injector)
232 } else {
233 router()
234 };
235
236 let addr: std::net::SocketAddr = format!("127.0.0.1:{}", port).parse()?;
237 info!("WebSocket server listening on {}", addr);
238
239 let listener = tokio::net::TcpListener::bind(addr).await.map_err(|e| {
240 format!(
241 "Failed to bind WebSocket server to port {}: {}\n\
242 Hint: The port may already be in use. Try using a different port with --ws-port or check if another process is using this port with: lsof -i :{} or netstat -tulpn | grep {}",
243 port, e, port, port
244 )
245 })?;
246
247 axum::serve(listener, router).await?;
248 Ok(())
249}
250
251async fn ws_handler_no_state(ws: WebSocketUpgrade) -> impl IntoResponse {
253 ws.on_upgrade(handle_socket)
254}
255
256async fn ws_handler_with_state(
257 ws: WebSocketUpgrade,
258 axum::extract::State(_latency): axum::extract::State<LatencyInjector>,
259) -> impl IntoResponse {
260 ws.on_upgrade(handle_socket)
261}
262
263async fn ws_handler_with_proxy(
264 ws: WebSocketUpgrade,
265 State(proxy): State<WsProxyHandler>,
266) -> impl IntoResponse {
267 ws.on_upgrade(move |socket| handle_socket_with_proxy(socket, proxy, "/ws".to_string()))
268}
269
270async fn ws_handler_with_proxy_path(
271 Path(path): Path<String>,
272 ws: WebSocketUpgrade,
273 State(proxy): State<WsProxyHandler>,
274) -> impl IntoResponse {
275 let full_path = format!("/ws/{}", path);
276 ws.on_upgrade(move |socket| handle_socket_with_proxy(socket, proxy, full_path))
277}
278
279async fn ws_handler_with_registry(
280 ws: WebSocketUpgrade,
281 State(registry): State<std::sync::Arc<HandlerRegistry>>,
282) -> impl IntoResponse {
283 ws.on_upgrade(move |socket| handle_socket_with_handlers(socket, registry, "/ws".to_string()))
284}
285
286async fn ws_handler_with_registry_path(
287 Path(path): Path<String>,
288 ws: WebSocketUpgrade,
289 State(registry): State<std::sync::Arc<HandlerRegistry>>,
290) -> impl IntoResponse {
291 let full_path = format!("/ws/{}", path);
292 ws.on_upgrade(move |socket| handle_socket_with_handlers(socket, registry, full_path))
293}
294
295async fn handle_socket(mut socket: WebSocket) {
296 use std::time::Instant;
297
298 let registry = get_global_registry();
300 let connection_start = Instant::now();
301 registry.record_ws_connection_established();
302 debug!("WebSocket connection established, tracking metrics");
303
304 let mut status = "normal";
306
307 if let Ok(replay_file) = std::env::var("MOCKFORGE_WS_REPLAY_FILE") {
309 info!("WebSocket replay mode enabled with file: {}", replay_file);
310 handle_socket_with_replay(socket, &replay_file).await;
311 } else {
312 while let Some(msg) = socket.recv().await {
314 match msg {
315 Ok(Message::Text(text)) => {
316 registry.record_ws_message_received();
317
318 let response = format!("echo: {}", text);
320 if socket.send(Message::Text(response.into())).await.is_err() {
321 status = "send_error";
322 break;
323 }
324 registry.record_ws_message_sent();
325 }
326 Ok(Message::Close(_)) => {
327 status = "client_close";
328 break;
329 }
330 Err(e) => {
331 error!("WebSocket error: {}", e);
332 registry.record_ws_error();
333 status = "error";
334 break;
335 }
336 _ => {}
337 }
338 }
339 }
340
341 let duration = connection_start.elapsed().as_secs_f64();
343 registry.record_ws_connection_closed(duration, status);
344 debug!("WebSocket connection closed (status: {}, duration: {:.2}s)", status, duration);
345}
346
347async fn handle_socket_with_replay(mut socket: WebSocket, replay_file: &str) {
348 let _registry = get_global_registry(); let content = match fs::read_to_string(replay_file).await {
352 Ok(content) => content,
353 Err(e) => {
354 error!("Failed to read replay file {}: {}", replay_file, e);
355 return;
356 }
357 };
358
359 let mut replay_entries = Vec::new();
361 for line in content.lines() {
362 if let Ok(entry) = serde_json::from_str::<Value>(line) {
363 replay_entries.push(entry);
364 }
365 }
366
367 info!("Loaded {} replay entries", replay_entries.len());
368
369 for entry in replay_entries {
371 if let Some(wait_for) = entry.get("waitFor") {
373 if let Some(wait_pattern) = wait_for.as_str() {
374 info!("Waiting for pattern: {}", wait_pattern);
375 let mut found = false;
377 while let Some(msg) = socket.recv().await {
378 if let Ok(Message::Text(text)) = msg {
379 if text.contains(wait_pattern) || wait_pattern == "^CLIENT_READY$" {
380 found = true;
381 break;
382 }
383 }
384 }
385 if !found {
386 break;
387 }
388 }
389 }
390
391 if let Some(text) = entry.get("text").and_then(|v| v.as_str()) {
393 let expanded_text = if std::env::var("MOCKFORGE_RESPONSE_TEMPLATE_EXPAND")
395 .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
396 .unwrap_or(false)
397 {
398 expand_tokens(text)
399 } else {
400 text.to_string()
401 };
402
403 info!("Sending replay message: {}", expanded_text);
404 if socket.send(Message::Text(expanded_text.into())).await.is_err() {
405 break;
406 }
407 }
408
409 if let Some(ts) = entry.get("ts").and_then(|v| v.as_u64()) {
411 sleep(Duration::from_millis(ts * 10)).await; }
413 }
414}
415
416fn expand_tokens(text: &str) -> String {
417 let mut result = text.to_string();
418
419 result = result.replace("{{uuid}}", &uuid::Uuid::new_v4().to_string());
421
422 result = result.replace("{{now}}", &chrono::Utc::now().to_rfc3339());
424
425 if result.contains("{{now+1m}}") {
427 let now_plus_1m = chrono::Utc::now() + chrono::Duration::minutes(1);
428 result = result.replace("{{now+1m}}", &now_plus_1m.to_rfc3339());
429 }
430
431 if result.contains("{{now+1h}}") {
433 let now_plus_1h = chrono::Utc::now() + chrono::Duration::hours(1);
434 result = result.replace("{{now+1h}}", &now_plus_1h.to_rfc3339());
435 }
436
437 while result.contains("{{randInt") {
439 if let Some(start) = result.find("{{randInt") {
440 if let Some(end) = result[start..].find("}}") {
441 let full_match = &result[start..start + end + 2];
442 let content = &result[start + 9..start + end]; if let Some(space_pos) = content.find(' ') {
445 let min_str = &content[..space_pos];
446 let max_str = &content[space_pos + 1..];
447
448 if let (Ok(min), Ok(max)) = (min_str.parse::<i32>(), max_str.parse::<i32>()) {
449 let random_value = fastrand::i32(min..=max);
450 result = result.replace(full_match, &random_value.to_string());
451 } else {
452 result = result.replace(full_match, "0");
453 }
454 } else {
455 result = result.replace(full_match, "0");
456 }
457 } else {
458 break;
459 }
460 } else {
461 break;
462 }
463 }
464
465 result
466}
467
468async fn handle_socket_with_proxy(socket: WebSocket, proxy: WsProxyHandler, path: String) {
469 use std::time::Instant;
470
471 let registry = get_global_registry();
472 let connection_start = Instant::now();
473 registry.record_ws_connection_established();
474
475 let mut status = "normal";
476
477 if proxy.config.should_proxy(&path) {
479 info!("Proxying WebSocket connection for path: {}", path);
480 if let Err(e) = proxy.proxy_connection(&path, socket).await {
481 error!("Failed to proxy WebSocket connection: {}", e);
482 registry.record_ws_error();
483 status = "proxy_error";
484 }
485 } else {
486 info!("Handling WebSocket connection locally for path: {}", path);
487 registry.record_ws_connection_closed(0.0, ""); handle_socket(socket).await;
492 return; }
494
495 let duration = connection_start.elapsed().as_secs_f64();
496 registry.record_ws_connection_closed(duration, status);
497 debug!(
498 "Proxied WebSocket connection closed (status: {}, duration: {:.2}s)",
499 status, duration
500 );
501}
502
503async fn handle_socket_with_handlers(
504 socket: WebSocket,
505 registry: std::sync::Arc<HandlerRegistry>,
506 path: String,
507) {
508 use std::time::Instant;
509
510 let metrics_registry = get_global_registry();
511 let connection_start = Instant::now();
512 metrics_registry.record_ws_connection_established();
513
514 let mut status = "normal";
515
516 let connection_id = uuid::Uuid::new_v4().to_string();
518
519 let handlers = registry.get_handlers(&path);
521 if handlers.is_empty() {
522 info!("No handlers found for path: {}, falling back to echo mode", path);
523 metrics_registry.record_ws_connection_closed(0.0, "");
524 handle_socket(socket).await;
525 return;
526 }
527
528 info!(
529 "Handling WebSocket connection with {} handler(s) for path: {}",
530 handlers.len(),
531 path
532 );
533
534 let room_manager = RoomManager::new();
536
537 let (mut socket_sender, mut socket_receiver) = socket.split();
539
540 let (message_tx, mut message_rx) = tokio::sync::mpsc::unbounded_channel::<Message>();
542
543 let mut ctx =
545 WsContext::new(connection_id.clone(), path.clone(), room_manager.clone(), message_tx);
546
547 for handler in &handlers {
549 if let Err(e) = handler.on_connect(&mut ctx).await {
550 error!("Handler on_connect error: {}", e);
551 status = "handler_error";
552 }
553 }
554
555 let send_task = tokio::spawn(async move {
557 while let Some(msg) = message_rx.recv().await {
558 if socket_sender.send(msg).await.is_err() {
559 break;
560 }
561 }
562 });
563
564 while let Some(msg) = socket_receiver.next().await {
566 match msg {
567 Ok(axum_msg) => {
568 metrics_registry.record_ws_message_received();
569
570 let ws_msg: WsMessage = axum_msg.into();
571
572 if matches!(ws_msg, WsMessage::Close) {
574 status = "client_close";
575 break;
576 }
577
578 for handler in &handlers {
580 if let Err(e) = handler.on_message(&mut ctx, ws_msg.clone()).await {
581 error!("Handler on_message error: {}", e);
582 status = "handler_error";
583 }
584 }
585
586 metrics_registry.record_ws_message_sent();
587 }
588 Err(e) => {
589 error!("WebSocket error: {}", e);
590 metrics_registry.record_ws_error();
591 status = "error";
592 break;
593 }
594 }
595 }
596
597 for handler in &handlers {
599 if let Err(e) = handler.on_disconnect(&mut ctx).await {
600 error!("Handler on_disconnect error: {}", e);
601 }
602 }
603
604 let _ = room_manager.leave_all(&connection_id).await;
606
607 send_task.abort();
609
610 let duration = connection_start.elapsed().as_secs_f64();
611 metrics_registry.record_ws_connection_closed(duration, status);
612 debug!(
613 "Handler-based WebSocket connection closed (status: {}, duration: {:.2}s)",
614 status, duration
615 );
616}
617
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_router_creation() {
        let _router = router();
    }

    #[test]
    fn test_router_with_latency_creation() {
        let latency_profile = LatencyProfile::default();
        let latency_injector = LatencyInjector::new(latency_profile, Default::default());
        let _router = router_with_latency(latency_injector);
    }

    #[test]
    fn test_router_with_proxy_creation() {
        let config = mockforge_core::WsProxyConfig {
            upstream_url: "ws://localhost:8080".to_string(),
            ..Default::default()
        };
        let proxy_handler = WsProxyHandler::new(config);
        let _router = router_with_proxy(proxy_handler);
    }

    /// Exercises the router that `start_with_latency(_, None)` would build;
    /// it does not bind a port.
    #[tokio::test]
    async fn test_start_with_latency_config_none() {
        let result = std::panic::catch_unwind(|| {
            let _router = router();
        });
        assert!(result.is_ok());
    }

    /// Exercises the router that `start_with_latency(_, Some(..))` would
    /// build; it does not bind a port.
    #[tokio::test]
    async fn test_start_with_latency_config_some() {
        let latency_profile = LatencyProfile::default();
        let latency_injector = LatencyInjector::new(latency_profile, Default::default());

        let result = std::panic::catch_unwind(|| {
            let _router = router_with_latency(latency_injector);
        });
        assert!(result.is_ok());
    }

    #[test]
    fn test_expand_tokens_without_tokens_is_identity() {
        assert_eq!(expand_tokens("hello world"), "hello world");
        assert_eq!(expand_tokens(""), "");
    }

    #[test]
    fn test_expand_tokens_replaces_uuid() {
        let out = expand_tokens("id={{uuid}}");
        let id = out.strip_prefix("id=").expect("literal prefix is preserved");
        assert!(uuid::Uuid::parse_str(id).is_ok(), "not a UUID: {id}");
    }

    #[test]
    fn test_expand_tokens_replaces_time_tokens() {
        let out = expand_tokens("{{now}}|{{now+1m}}|{{now+1h}}");
        assert!(!out.contains("{{"), "unexpanded token in {out}");
        for part in out.split('|') {
            assert!(chrono::DateTime::parse_from_rfc3339(part).is_ok(), "bad timestamp {part}");
        }
    }

    #[test]
    fn test_expand_tokens_malformed_rand_int_becomes_zero() {
        assert_eq!(expand_tokens("{{randInt}}"), "0");
        assert_eq!(expand_tokens("n={{randInt x y}}"), "n=0");
    }

    #[test]
    fn test_expand_tokens_unterminated_rand_int_is_left_alone() {
        assert_eq!(expand_tokens("{{randInt 1 2"), "{{randInt 1 2");
    }
}