1pub mod ai_event_generator;
152pub mod handlers;
153pub mod protocol_server;
155pub mod ws_tracing;
156
157use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
158use axum::extract::{Path, State};
159use axum::{response::IntoResponse, routing::get, Router};
160use futures::sink::SinkExt;
161use futures::stream::StreamExt;
162use mockforge_core::{latency::LatencyInjector, LatencyProfile, WsProxyHandler};
163#[cfg(feature = "data-faker")]
164use mockforge_data::provider::register_core_faker_provider;
165use mockforge_observability::get_global_registry;
166use serde_json::Value;
167use tokio::fs;
168use tokio::time::{sleep, Duration};
169use tracing::*;
170
171pub use ai_event_generator::{AiEventGenerator, WebSocketAiConfig};
173
174pub use ws_tracing::{
176 create_ws_connection_span, create_ws_message_span, record_ws_connection_success,
177 record_ws_error, record_ws_message_success,
178};
179
180pub use handlers::{
182 HandlerError, HandlerRegistry, HandlerResult, MessagePattern, MessageRouter, PassthroughConfig,
183 PassthroughHandler, RoomManager, WsContext, WsHandler, WsMessage,
184};
185
186pub fn router() -> Router {
188 #[cfg(feature = "data-faker")]
189 register_core_faker_provider();
190
191 Router::new().route("/ws", get(ws_handler_no_state))
192}
193
194pub fn router_with_latency(latency_injector: LatencyInjector) -> Router {
196 #[cfg(feature = "data-faker")]
197 register_core_faker_provider();
198
199 Router::new()
200 .route("/ws", get(ws_handler_with_state))
201 .with_state(latency_injector)
202}
203
204pub fn router_with_proxy(proxy_handler: WsProxyHandler) -> Router {
206 #[cfg(feature = "data-faker")]
207 register_core_faker_provider();
208
209 Router::new()
210 .route("/ws", get(ws_handler_with_proxy))
211 .route("/ws/{*path}", get(ws_handler_with_proxy_path))
212 .with_state(proxy_handler)
213}
214
215pub fn router_with_handlers(registry: std::sync::Arc<HandlerRegistry>) -> Router {
217 #[cfg(feature = "data-faker")]
218 register_core_faker_provider();
219
220 Router::new()
221 .route("/ws", get(ws_handler_with_registry))
222 .route("/ws/{*path}", get(ws_handler_with_registry_path))
223 .with_state(registry)
224}
225
226pub async fn start_with_latency(
228 port: u16,
229 latency: Option<LatencyProfile>,
230) -> Result<(), Box<dyn std::error::Error>> {
231 start_with_latency_and_host(port, "0.0.0.0", latency).await
232}
233
234pub async fn start_with_latency_and_host(
236 port: u16,
237 host: &str,
238 latency: Option<LatencyProfile>,
239) -> Result<(), Box<dyn std::error::Error>> {
240 let latency_injector = latency.map(|profile| LatencyInjector::new(profile, Default::default()));
241 let router = if let Some(injector) = latency_injector {
242 router_with_latency(injector)
243 } else {
244 router()
245 };
246
247 let addr: std::net::SocketAddr = format!("{}:{}", host, port).parse()?;
248 info!("WebSocket server listening on {}", addr);
249
250 let listener = tokio::net::TcpListener::bind(addr).await.map_err(|e| {
251 format!(
252 "Failed to bind WebSocket server to port {}: {}\n\
253 Hint: The port may already be in use. Try using a different port with --ws-port or check if another process is using this port with: lsof -i :{} or netstat -tulpn | grep {}",
254 port, e, port, port
255 )
256 })?;
257
258 axum::serve(listener, router).await?;
259 Ok(())
260}
261
262async fn ws_handler_no_state(ws: WebSocketUpgrade) -> impl IntoResponse {
264 ws.on_upgrade(handle_socket)
265}
266
267async fn ws_handler_with_state(
268 ws: WebSocketUpgrade,
269 State(_latency): State<LatencyInjector>,
270) -> impl IntoResponse {
271 ws.on_upgrade(handle_socket)
272}
273
274async fn ws_handler_with_proxy(
275 ws: WebSocketUpgrade,
276 State(proxy): State<WsProxyHandler>,
277) -> impl IntoResponse {
278 ws.on_upgrade(move |socket| handle_socket_with_proxy(socket, proxy, "/ws".to_string()))
279}
280
281async fn ws_handler_with_proxy_path(
282 Path(path): Path<String>,
283 ws: WebSocketUpgrade,
284 State(proxy): State<WsProxyHandler>,
285) -> impl IntoResponse {
286 let full_path = format!("/ws/{}", path);
287 ws.on_upgrade(move |socket| handle_socket_with_proxy(socket, proxy, full_path))
288}
289
290async fn ws_handler_with_registry(
291 ws: WebSocketUpgrade,
292 State(registry): State<std::sync::Arc<HandlerRegistry>>,
293) -> impl IntoResponse {
294 ws.on_upgrade(move |socket| handle_socket_with_handlers(socket, registry, "/ws".to_string()))
295}
296
297async fn ws_handler_with_registry_path(
298 Path(path): Path<String>,
299 ws: WebSocketUpgrade,
300 State(registry): State<std::sync::Arc<HandlerRegistry>>,
301) -> impl IntoResponse {
302 let full_path = format!("/ws/{}", path);
303 ws.on_upgrade(move |socket| handle_socket_with_handlers(socket, registry, full_path))
304}
305
306async fn handle_socket(mut socket: WebSocket) {
307 use std::time::Instant;
308
309 let registry = get_global_registry();
311 let connection_start = Instant::now();
312 registry.record_ws_connection_established();
313 debug!("WebSocket connection established, tracking metrics");
314
315 let mut status = "normal";
317
318 if let Ok(replay_file) = std::env::var("MOCKFORGE_WS_REPLAY_FILE") {
320 info!("WebSocket replay mode enabled with file: {}", replay_file);
321 handle_socket_with_replay(socket, &replay_file).await;
322 } else {
323 while let Some(msg) = socket.recv().await {
325 match msg {
326 Ok(Message::Text(text)) => {
327 registry.record_ws_message_received();
328
329 let response = format!("echo: {}", text);
331 if socket.send(Message::Text(response.into())).await.is_err() {
332 status = "send_error";
333 break;
334 }
335 registry.record_ws_message_sent();
336 }
337 Ok(Message::Close(_)) => {
338 status = "client_close";
339 break;
340 }
341 Err(e) => {
342 error!("WebSocket error: {}", e);
343 registry.record_ws_error();
344 status = "error";
345 break;
346 }
347 _ => {}
348 }
349 }
350 }
351
352 let duration = connection_start.elapsed().as_secs_f64();
354 registry.record_ws_connection_closed(duration, status);
355 debug!("WebSocket connection closed (status: {}, duration: {:.2}s)", status, duration);
356}
357
358async fn handle_socket_with_replay(mut socket: WebSocket, replay_file: &str) {
359 let _registry = get_global_registry(); let content = match fs::read_to_string(replay_file).await {
363 Ok(content) => content,
364 Err(e) => {
365 error!("Failed to read replay file {}: {}", replay_file, e);
366 return;
367 }
368 };
369
370 let mut replay_entries = Vec::new();
372 for line in content.lines() {
373 if let Ok(entry) = serde_json::from_str::<Value>(line) {
374 replay_entries.push(entry);
375 }
376 }
377
378 info!("Loaded {} replay entries", replay_entries.len());
379
380 for entry in replay_entries {
382 if let Some(wait_for) = entry.get("waitFor") {
384 if let Some(wait_pattern) = wait_for.as_str() {
385 info!("Waiting for pattern: {}", wait_pattern);
386 let mut found = false;
388 while let Some(msg) = socket.recv().await {
389 if let Ok(Message::Text(text)) = msg {
390 if text.contains(wait_pattern) || wait_pattern == "^CLIENT_READY$" {
391 found = true;
392 break;
393 }
394 }
395 }
396 if !found {
397 break;
398 }
399 }
400 }
401
402 if let Some(text) = entry.get("text").and_then(|v| v.as_str()) {
404 let expanded_text = if std::env::var("MOCKFORGE_RESPONSE_TEMPLATE_EXPAND")
406 .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
407 .unwrap_or(false)
408 {
409 expand_tokens(text)
410 } else {
411 text.to_string()
412 };
413
414 info!("Sending replay message: {}", expanded_text);
415 if socket.send(Message::Text(expanded_text.into())).await.is_err() {
416 break;
417 }
418 }
419
420 if let Some(ts) = entry.get("ts").and_then(|v| v.as_u64()) {
422 sleep(Duration::from_millis(ts * 10)).await; }
424 }
425}
426
427fn expand_tokens(text: &str) -> String {
428 let mut result = text.to_string();
429
430 result = result.replace("{{uuid}}", &uuid::Uuid::new_v4().to_string());
432
433 result = result.replace("{{now}}", &chrono::Utc::now().to_rfc3339());
435
436 if result.contains("{{now+1m}}") {
438 let now_plus_1m = chrono::Utc::now() + chrono::Duration::minutes(1);
439 result = result.replace("{{now+1m}}", &now_plus_1m.to_rfc3339());
440 }
441
442 if result.contains("{{now+1h}}") {
444 let now_plus_1h = chrono::Utc::now() + chrono::Duration::hours(1);
445 result = result.replace("{{now+1h}}", &now_plus_1h.to_rfc3339());
446 }
447
448 while result.contains("{{randInt") {
450 if let Some(start) = result.find("{{randInt") {
451 if let Some(end) = result[start..].find("}}") {
452 let full_match = &result[start..start + end + 2];
453 let content = &result[start + 9..start + end]; if let Some(space_pos) = content.find(' ') {
456 let min_str = &content[..space_pos];
457 let max_str = &content[space_pos + 1..];
458
459 if let (Ok(min), Ok(max)) = (min_str.parse::<i32>(), max_str.parse::<i32>()) {
460 let random_value = fastrand::i32(min..=max);
461 result = result.replace(full_match, &random_value.to_string());
462 } else {
463 result = result.replace(full_match, "0");
464 }
465 } else {
466 result = result.replace(full_match, "0");
467 }
468 } else {
469 break;
470 }
471 } else {
472 break;
473 }
474 }
475
476 result
477}
478
479async fn handle_socket_with_proxy(socket: WebSocket, proxy: WsProxyHandler, path: String) {
480 use std::time::Instant;
481
482 let registry = get_global_registry();
483 let connection_start = Instant::now();
484 registry.record_ws_connection_established();
485
486 let mut status = "normal";
487
488 if proxy.config.should_proxy(&path) {
490 info!("Proxying WebSocket connection for path: {}", path);
491 if let Err(e) = proxy.proxy_connection(&path, socket).await {
492 error!("Failed to proxy WebSocket connection: {}", e);
493 registry.record_ws_error();
494 status = "proxy_error";
495 }
496 } else {
497 info!("Handling WebSocket connection locally for path: {}", path);
498 registry.record_ws_connection_closed(0.0, ""); handle_socket(socket).await;
503 return; }
505
506 let duration = connection_start.elapsed().as_secs_f64();
507 registry.record_ws_connection_closed(duration, status);
508 debug!(
509 "Proxied WebSocket connection closed (status: {}, duration: {:.2}s)",
510 status, duration
511 );
512}
513
514async fn handle_socket_with_handlers(
515 socket: WebSocket,
516 registry: std::sync::Arc<HandlerRegistry>,
517 path: String,
518) {
519 use std::time::Instant;
520
521 let metrics_registry = get_global_registry();
522 let connection_start = Instant::now();
523 metrics_registry.record_ws_connection_established();
524
525 let mut status = "normal";
526
527 let connection_id = uuid::Uuid::new_v4().to_string();
529
530 let handlers = registry.get_handlers(&path);
532 if handlers.is_empty() {
533 info!("No handlers found for path: {}, falling back to echo mode", path);
534 metrics_registry.record_ws_connection_closed(0.0, "");
535 handle_socket(socket).await;
536 return;
537 }
538
539 info!(
540 "Handling WebSocket connection with {} handler(s) for path: {}",
541 handlers.len(),
542 path
543 );
544
545 let room_manager = RoomManager::new();
547
548 let (mut socket_sender, mut socket_receiver) = socket.split();
550
551 let (message_tx, mut message_rx) = tokio::sync::mpsc::unbounded_channel::<Message>();
553
554 let mut ctx =
556 WsContext::new(connection_id.clone(), path.clone(), room_manager.clone(), message_tx);
557
558 for handler in &handlers {
560 if let Err(e) = handler.on_connect(&mut ctx).await {
561 error!("Handler on_connect error: {}", e);
562 status = "handler_error";
563 }
564 }
565
566 let send_task = tokio::spawn(async move {
568 while let Some(msg) = message_rx.recv().await {
569 if socket_sender.send(msg).await.is_err() {
570 break;
571 }
572 }
573 });
574
575 while let Some(msg) = socket_receiver.next().await {
577 match msg {
578 Ok(axum_msg) => {
579 metrics_registry.record_ws_message_received();
580
581 let ws_msg: WsMessage = axum_msg.into();
582
583 if matches!(ws_msg, WsMessage::Close) {
585 status = "client_close";
586 break;
587 }
588
589 for handler in &handlers {
591 if let Err(e) = handler.on_message(&mut ctx, ws_msg.clone()).await {
592 error!("Handler on_message error: {}", e);
593 status = "handler_error";
594 }
595 }
596
597 metrics_registry.record_ws_message_sent();
598 }
599 Err(e) => {
600 error!("WebSocket error: {}", e);
601 metrics_registry.record_ws_error();
602 status = "error";
603 break;
604 }
605 }
606 }
607
608 for handler in &handlers {
610 if let Err(e) = handler.on_disconnect(&mut ctx).await {
611 error!("Handler on_disconnect error: {}", e);
612 }
613 }
614
615 let _ = room_manager.leave_all(&connection_id).await;
617
618 send_task.abort();
620
621 let duration = connection_start.elapsed().as_secs_f64();
622 metrics_registry.record_ws_connection_closed(duration, status);
623 debug!(
624 "Handler-based WebSocket connection closed (status: {}, duration: {:.2}s)",
625 status, duration
626 );
627}
628
#[cfg(test)]
mod tests {
    use super::*;

    // ---- Router construction -------------------------------------------------------------

    #[test]
    fn test_router_creation() {
        let _router = router();
    }

    #[test]
    fn test_router_with_latency_creation() {
        let latency_profile = LatencyProfile::default();
        let latency_injector = LatencyInjector::new(latency_profile, Default::default());
        let _router = router_with_latency(latency_injector);
    }

    #[test]
    fn test_router_with_proxy_creation() {
        let config = mockforge_core::WsProxyConfig {
            upstream_url: "ws://localhost:8080".to_string(),
            ..Default::default()
        };
        let proxy_handler = WsProxyHandler::new(config);
        let _router = router_with_proxy(proxy_handler);
    }

    #[test]
    fn test_router_with_handlers_creation() {
        let registry = std::sync::Arc::new(HandlerRegistry::new());
        let _router = router_with_handlers(registry);
    }

    // These build the routers `start_with_latency_and_host` would serve, without binding a
    // listener. A panic fails the test by itself, so no `catch_unwind` wrapper is needed.
    #[tokio::test]
    async fn test_start_with_latency_config_none() {
        let _router = router();
    }

    #[tokio::test]
    async fn test_start_with_latency_config_some() {
        let latency_profile = LatencyProfile::default();
        let latency_injector = LatencyInjector::new(latency_profile, Default::default());
        let _router = router_with_latency(latency_injector);
    }

    // ---- Template token expansion -------------------------------------------------------

    #[test]
    fn test_expand_tokens_uuid() {
        let expanded = expand_tokens("session-{{uuid}}");
        assert!(!expanded.contains("{{uuid}}"));
        let uuid_part = expanded
            .strip_prefix("session-")
            .expect("prefix should be preserved");
        assert_eq!(uuid_part.len(), 36);
        assert!(uuid::Uuid::parse_str(uuid_part).is_ok());
    }

    #[test]
    fn test_expand_tokens_now() {
        let expanded = expand_tokens("time: {{now}}");
        assert!(!expanded.contains("{{now}}"));
        let stamp = expanded.strip_prefix("time: ").expect("prefix should be preserved");
        assert!(chrono::DateTime::parse_from_rfc3339(stamp).is_ok());
    }

    #[test]
    fn test_expand_tokens_now_plus_1m() {
        let expanded = expand_tokens("expires: {{now+1m}}");
        assert!(!expanded.contains("{{now+1m}}"));
        let stamp = expanded
            .strip_prefix("expires: ")
            .expect("prefix should be preserved");
        let parsed = chrono::DateTime::parse_from_rfc3339(stamp).expect("valid RFC 3339");
        assert!(parsed.with_timezone(&chrono::Utc) > chrono::Utc::now());
    }

    #[test]
    fn test_expand_tokens_now_plus_1h() {
        let expanded = expand_tokens("expires: {{now+1h}}");
        assert!(!expanded.contains("{{now+1h}}"));
        let stamp = expanded
            .strip_prefix("expires: ")
            .expect("prefix should be preserved");
        let parsed = chrono::DateTime::parse_from_rfc3339(stamp).expect("valid RFC 3339");
        assert!(
            parsed.with_timezone(&chrono::Utc)
                > chrono::Utc::now() + chrono::Duration::minutes(59)
        );
    }

    #[test]
    fn test_expand_tokens_randint() {
        let expanded = expand_tokens("value: {{randInt 1 100}}");
        assert!(!expanded.contains("{{randInt"), "Token should be expanded");
        assert!(expanded.starts_with("value: "));
    }

    #[test]
    fn test_expand_tokens_randint_multiple() {
        let expanded = expand_tokens("a: {{randInt 1 10}}, b: {{randInt 20 30}}");
        assert!(!expanded.contains("{{randInt"));
        assert!(expanded.contains("a: "));
        assert!(expanded.contains("b: "));
    }

    #[test]
    fn test_expand_tokens_randint_invalid_args_expand_to_zero() {
        assert_eq!(expand_tokens("n={{randInt a b}}"), "n=0");
        assert_eq!(expand_tokens("n={{randInt}}"), "n=0");
    }

    #[test]
    fn test_expand_tokens_unterminated_randint_left_untouched() {
        assert_eq!(expand_tokens("x {{randInt 1 2"), "x {{randInt 1 2");
    }

    #[test]
    fn test_expand_tokens_mixed() {
        let expanded = expand_tokens("id: {{uuid}}, time: {{now}}, rand: {{randInt 1 10}}");
        assert!(!expanded.contains("{{uuid}}"));
        assert!(!expanded.contains("{{now}}"));
        assert!(!expanded.contains("{{randInt"));
    }

    #[test]
    fn test_expand_tokens_no_tokens() {
        let text = "plain text without tokens";
        assert_eq!(expand_tokens(text), text);
    }

    // ---- Latency profiles ---------------------------------------------------------------

    #[test]
    fn test_latency_profile_default() {
        let profile = LatencyProfile::default();
        let injector = LatencyInjector::new(profile, Default::default());
        let _router = router_with_latency(injector);
    }

    #[test]
    fn test_latency_profile_with_normal_distribution() {
        let profile = LatencyProfile::with_normal_distribution(100, 25.0)
            .with_min_ms(50)
            .with_max_ms(200);
        let injector = LatencyInjector::new(profile, Default::default());
        let _router = router_with_latency(injector);
    }

    // ---- Proxy configuration ------------------------------------------------------------

    #[test]
    fn test_ws_proxy_config_default() {
        let config = mockforge_core::WsProxyConfig::default();
        let _url = &config.upstream_url;
    }

    #[test]
    fn test_ws_proxy_config_custom() {
        let config = mockforge_core::WsProxyConfig {
            upstream_url: "wss://api.example.com/ws".to_string(),
            ..Default::default()
        };
        assert_eq!(config.upstream_url, "wss://api.example.com/ws");
    }

    // ---- Re-exports ---------------------------------------------------------------------

    #[test]
    fn test_reexports_available() {
        let _ = create_ws_connection_span("conn-123");
        let _registry = HandlerRegistry::new();
        let _pattern = MessagePattern::any();
    }
}