1pub mod ai_event_generator;
152pub mod handlers;
153pub mod protocol_server;
155pub mod ws_tracing;
156
157use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
158use axum::extract::{Path, State};
159use axum::{response::IntoResponse, routing::get, Router};
160use futures::sink::SinkExt;
161use futures::stream::StreamExt;
162use mockforge_core::WsProxyHandler;
163#[cfg(feature = "data-faker")]
164use mockforge_data::provider::register_core_faker_provider;
165use mockforge_foundation::latency::{LatencyInjector, LatencyProfile};
166use mockforge_observability::get_global_registry;
167use serde_json::Value;
168use tokio::fs;
169use tokio::time::{sleep, Duration};
170use tracing::*;
171
172pub use ai_event_generator::{AiEventGenerator, WebSocketAiConfig};
174
175pub use ws_tracing::{
177 create_ws_connection_span, create_ws_message_span, record_ws_connection_success,
178 record_ws_error, record_ws_message_success,
179};
180
181pub use handlers::{
183 HandlerError, HandlerRegistry, HandlerResult, MessagePattern, MessageRouter, PassthroughConfig,
184 PassthroughHandler, RoomManager, WsContext, WsHandler, WsMessage,
185};
186
187pub fn router() -> Router {
189 #[cfg(feature = "data-faker")]
190 register_core_faker_provider();
191
192 Router::new().route("/ws", get(ws_handler_no_state))
193}
194
195pub fn router_with_latency(latency_injector: LatencyInjector) -> Router {
197 #[cfg(feature = "data-faker")]
198 register_core_faker_provider();
199
200 Router::new()
201 .route("/ws", get(ws_handler_with_state))
202 .with_state(latency_injector)
203}
204
205pub fn router_with_proxy(proxy_handler: WsProxyHandler) -> Router {
207 #[cfg(feature = "data-faker")]
208 register_core_faker_provider();
209
210 Router::new()
211 .route("/ws", get(ws_handler_with_proxy))
212 .route("/ws/{*path}", get(ws_handler_with_proxy_path))
213 .with_state(proxy_handler)
214}
215
216pub fn router_with_handlers(registry: std::sync::Arc<HandlerRegistry>) -> Router {
218 #[cfg(feature = "data-faker")]
219 register_core_faker_provider();
220
221 Router::new()
222 .route("/ws", get(ws_handler_with_registry))
223 .route("/ws/{*path}", get(ws_handler_with_registry_path))
224 .with_state(registry)
225}
226
227pub async fn start_with_latency(
229 port: u16,
230 latency: Option<LatencyProfile>,
231) -> Result<(), Box<dyn std::error::Error>> {
232 start_with_latency_and_host(port, "0.0.0.0", latency).await
233}
234
235pub async fn start_with_latency_and_host(
237 port: u16,
238 host: &str,
239 latency: Option<LatencyProfile>,
240) -> Result<(), Box<dyn std::error::Error>> {
241 let latency_injector = latency.map(|profile| LatencyInjector::new(profile, Default::default()));
242 let router = if let Some(injector) = latency_injector {
243 router_with_latency(injector)
244 } else {
245 router()
246 };
247
248 let addr: std::net::SocketAddr = format!("{}:{}", host, port).parse()?;
249 info!("WebSocket server listening on {}", addr);
250
251 let listener = tokio::net::TcpListener::bind(addr).await.map_err(|e| {
252 format!(
253 "Failed to bind WebSocket server to port {}: {}\n\
254 Hint: The port may already be in use. Try using a different port with --ws-port or check if another process is using this port with: lsof -i :{} or netstat -tulpn | grep {}",
255 port, e, port, port
256 )
257 })?;
258
259 axum::serve(listener, router).await?;
260 Ok(())
261}
262
263async fn ws_handler_no_state(ws: WebSocketUpgrade) -> impl IntoResponse {
265 ws.on_upgrade(handle_socket)
266}
267
268async fn ws_handler_with_state(
269 ws: WebSocketUpgrade,
270 State(_latency): State<LatencyInjector>,
271) -> impl IntoResponse {
272 ws.on_upgrade(handle_socket)
273}
274
275async fn ws_handler_with_proxy(
276 ws: WebSocketUpgrade,
277 State(proxy): State<WsProxyHandler>,
278) -> impl IntoResponse {
279 ws.on_upgrade(move |socket| handle_socket_with_proxy(socket, proxy, "/ws".to_string()))
280}
281
282async fn ws_handler_with_proxy_path(
283 Path(path): Path<String>,
284 ws: WebSocketUpgrade,
285 State(proxy): State<WsProxyHandler>,
286) -> impl IntoResponse {
287 let full_path = format!("/ws/{}", path);
288 ws.on_upgrade(move |socket| handle_socket_with_proxy(socket, proxy, full_path))
289}
290
291async fn ws_handler_with_registry(
292 ws: WebSocketUpgrade,
293 State(registry): State<std::sync::Arc<HandlerRegistry>>,
294) -> impl IntoResponse {
295 ws.on_upgrade(move |socket| handle_socket_with_handlers(socket, registry, "/ws".to_string()))
296}
297
298async fn ws_handler_with_registry_path(
299 Path(path): Path<String>,
300 ws: WebSocketUpgrade,
301 State(registry): State<std::sync::Arc<HandlerRegistry>>,
302) -> impl IntoResponse {
303 let full_path = format!("/ws/{}", path);
304 ws.on_upgrade(move |socket| handle_socket_with_handlers(socket, registry, full_path))
305}
306
307async fn handle_socket(mut socket: WebSocket) {
308 use std::time::Instant;
309
310 let registry = get_global_registry();
312 let connection_start = Instant::now();
313 registry.record_ws_connection_established();
314 debug!("WebSocket connection established, tracking metrics");
315
316 let mut status = "normal";
318
319 if let Ok(replay_file) = std::env::var("MOCKFORGE_WS_REPLAY_FILE") {
321 info!("WebSocket replay mode enabled with file: {}", replay_file);
322 handle_socket_with_replay(socket, &replay_file).await;
323 } else {
324 while let Some(msg) = socket.recv().await {
326 match msg {
327 Ok(Message::Text(text)) => {
328 registry.record_ws_message_received();
329
330 let response = format!("echo: {}", text);
332 if socket.send(Message::Text(response.into())).await.is_err() {
333 status = "send_error";
334 break;
335 }
336 registry.record_ws_message_sent();
337 }
338 Ok(Message::Close(_)) => {
339 status = "client_close";
340 break;
341 }
342 Err(e) => {
343 error!("WebSocket error: {}", e);
344 registry.record_ws_error();
345 status = "error";
346 break;
347 }
348 _ => {}
349 }
350 }
351 }
352
353 let duration = connection_start.elapsed().as_secs_f64();
355 registry.record_ws_connection_closed(duration, status);
356 debug!("WebSocket connection closed (status: {}, duration: {:.2}s)", status, duration);
357}
358
359async fn handle_socket_with_replay(mut socket: WebSocket, replay_file: &str) {
360 let _registry = get_global_registry(); let content = match fs::read_to_string(replay_file).await {
364 Ok(content) => content,
365 Err(e) => {
366 error!("Failed to read replay file {}: {}", replay_file, e);
367 return;
368 }
369 };
370
371 let mut replay_entries = Vec::new();
373 for line in content.lines() {
374 if let Ok(entry) = serde_json::from_str::<Value>(line) {
375 replay_entries.push(entry);
376 }
377 }
378
379 info!("Loaded {} replay entries", replay_entries.len());
380
381 for entry in replay_entries {
383 if let Some(wait_for) = entry.get("waitFor") {
385 if let Some(wait_pattern) = wait_for.as_str() {
386 info!("Waiting for pattern: {}", wait_pattern);
387 let mut found = false;
389 while let Some(msg) = socket.recv().await {
390 if let Ok(Message::Text(text)) = msg {
391 if text.contains(wait_pattern) || wait_pattern == "^CLIENT_READY$" {
392 found = true;
393 break;
394 }
395 }
396 }
397 if !found {
398 break;
399 }
400 }
401 }
402
403 if let Some(text) = entry.get("text").and_then(|v| v.as_str()) {
405 let expanded_text = if std::env::var("MOCKFORGE_RESPONSE_TEMPLATE_EXPAND")
407 .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
408 .unwrap_or(false)
409 {
410 expand_tokens(text)
411 } else {
412 text.to_string()
413 };
414
415 info!("Sending replay message: {}", expanded_text);
416 if socket.send(Message::Text(expanded_text.into())).await.is_err() {
417 break;
418 }
419 }
420
421 if let Some(ts) = entry.get("ts").and_then(|v| v.as_u64()) {
423 sleep(Duration::from_millis(ts * 10)).await; }
425 }
426}
427
428fn expand_tokens(text: &str) -> String {
429 let mut result = text.to_string();
430
431 result = result.replace("{{uuid}}", &uuid::Uuid::new_v4().to_string());
433
434 result = result.replace("{{now}}", &chrono::Utc::now().to_rfc3339());
436
437 if result.contains("{{now+1m}}") {
439 let now_plus_1m = chrono::Utc::now() + chrono::Duration::minutes(1);
440 result = result.replace("{{now+1m}}", &now_plus_1m.to_rfc3339());
441 }
442
443 if result.contains("{{now+1h}}") {
445 let now_plus_1h = chrono::Utc::now() + chrono::Duration::hours(1);
446 result = result.replace("{{now+1h}}", &now_plus_1h.to_rfc3339());
447 }
448
449 while result.contains("{{randInt") {
451 if let Some(start) = result.find("{{randInt") {
452 if let Some(end) = result[start..].find("}}") {
453 let full_match = &result[start..start + end + 2];
454 let content = &result[start + 9..start + end]; if let Some(space_pos) = content.find(' ') {
457 let min_str = &content[..space_pos];
458 let max_str = &content[space_pos + 1..];
459
460 if let (Ok(min), Ok(max)) = (min_str.parse::<i32>(), max_str.parse::<i32>()) {
461 let random_value = fastrand::i32(min..=max);
462 result = result.replace(full_match, &random_value.to_string());
463 } else {
464 result = result.replace(full_match, "0");
465 }
466 } else {
467 result = result.replace(full_match, "0");
468 }
469 } else {
470 break;
471 }
472 } else {
473 break;
474 }
475 }
476
477 result
478}
479
480async fn handle_socket_with_proxy(socket: WebSocket, proxy: WsProxyHandler, path: String) {
481 use std::time::Instant;
482
483 let registry = get_global_registry();
484 let connection_start = Instant::now();
485 registry.record_ws_connection_established();
486
487 let mut status = "normal";
488
489 if proxy.config.should_proxy(&path) {
491 info!("Proxying WebSocket connection for path: {}", path);
492 if let Err(e) = proxy.proxy_connection(&path, socket).await {
493 error!("Failed to proxy WebSocket connection: {}", e);
494 registry.record_ws_error();
495 status = "proxy_error";
496 }
497 } else {
498 info!("Handling WebSocket connection locally for path: {}", path);
499 registry.record_ws_connection_closed(0.0, ""); handle_socket(socket).await;
504 return; }
506
507 let duration = connection_start.elapsed().as_secs_f64();
508 registry.record_ws_connection_closed(duration, status);
509 debug!(
510 "Proxied WebSocket connection closed (status: {}, duration: {:.2}s)",
511 status, duration
512 );
513}
514
515async fn handle_socket_with_handlers(
516 socket: WebSocket,
517 registry: std::sync::Arc<HandlerRegistry>,
518 path: String,
519) {
520 use std::time::Instant;
521
522 let metrics_registry = get_global_registry();
523 let connection_start = Instant::now();
524 metrics_registry.record_ws_connection_established();
525
526 let mut status = "normal";
527
528 let connection_id = uuid::Uuid::new_v4().to_string();
530
531 let handlers = registry.get_handlers(&path);
533 if handlers.is_empty() {
534 info!("No handlers found for path: {}, falling back to echo mode", path);
535 metrics_registry.record_ws_connection_closed(0.0, "");
536 handle_socket(socket).await;
537 return;
538 }
539
540 info!(
541 "Handling WebSocket connection with {} handler(s) for path: {}",
542 handlers.len(),
543 path
544 );
545
546 let room_manager = RoomManager::new();
548
549 let (mut socket_sender, mut socket_receiver) = socket.split();
551
552 let (message_tx, mut message_rx) = tokio::sync::mpsc::unbounded_channel::<Message>();
554
555 let mut ctx =
557 WsContext::new(connection_id.clone(), path.clone(), room_manager.clone(), message_tx);
558
559 for handler in &handlers {
561 if let Err(e) = handler.on_connect(&mut ctx).await {
562 error!("Handler on_connect error: {}", e);
563 status = "handler_error";
564 }
565 }
566
567 let send_task = tokio::spawn(async move {
569 while let Some(msg) = message_rx.recv().await {
570 if socket_sender.send(msg).await.is_err() {
571 break;
572 }
573 }
574 });
575
576 while let Some(msg) = socket_receiver.next().await {
578 match msg {
579 Ok(axum_msg) => {
580 metrics_registry.record_ws_message_received();
581
582 let ws_msg: WsMessage = axum_msg.into();
583
584 if matches!(ws_msg, WsMessage::Close) {
586 status = "client_close";
587 break;
588 }
589
590 for handler in &handlers {
592 if let Err(e) = handler.on_message(&mut ctx, ws_msg.clone()).await {
593 error!("Handler on_message error: {}", e);
594 status = "handler_error";
595 }
596 }
597
598 metrics_registry.record_ws_message_sent();
599 }
600 Err(e) => {
601 error!("WebSocket error: {}", e);
602 metrics_registry.record_ws_error();
603 status = "error";
604 break;
605 }
606 }
607 }
608
609 for handler in &handlers {
611 if let Err(e) = handler.on_disconnect(&mut ctx).await {
612 error!("Handler on_disconnect error: {}", e);
613 }
614 }
615
616 let _ = room_manager.leave_all(&connection_id).await;
618
619 send_task.abort();
621
622 let duration = connection_start.elapsed().as_secs_f64();
623 metrics_registry.record_ws_connection_closed(duration, status);
624 debug!(
625 "Handler-based WebSocket connection closed (status: {}, duration: {:.2}s)",
626 status, duration
627 );
628}
629
#[cfg(test)]
mod tests {
    use super::*;

    // ------------------------------------------------------------------
    // Router construction: building each router must not panic.
    // ------------------------------------------------------------------

    #[test]
    fn test_router_creation() {
        let _router = router();
    }

    #[test]
    fn test_router_with_latency_creation() {
        let latency_profile = LatencyProfile::default();
        let latency_injector = LatencyInjector::new(latency_profile, Default::default());
        let _router = router_with_latency(latency_injector);
    }

    #[test]
    fn test_router_with_proxy_creation() {
        let config = mockforge_core::WsProxyConfig {
            upstream_url: "ws://localhost:8080".to_string(),
            ..Default::default()
        };
        let proxy_handler = WsProxyHandler::new(config);
        let _router = router_with_proxy(proxy_handler);
    }

    #[test]
    fn test_router_with_handlers_creation() {
        let registry = std::sync::Arc::new(HandlerRegistry::new());
        let _router = router_with_handlers(registry);
    }

    // `start_with_latency` binds a real socket, so these only exercise the
    // router it selects for each `latency` variant. A panic already fails a
    // test (no `catch_unwind` needed), and nothing here awaits, so plain
    // `#[test]` is enough.
    #[test]
    fn test_start_with_latency_config_none() {
        let _router = router();
    }

    #[test]
    fn test_start_with_latency_config_some() {
        let latency_injector = LatencyInjector::new(LatencyProfile::default(), Default::default());
        let _router = router_with_latency(latency_injector);
    }

    // ------------------------------------------------------------------
    // Template token expansion
    // ------------------------------------------------------------------

    #[test]
    fn test_expand_tokens_uuid() {
        let expanded = expand_tokens("session-{{uuid}}");
        assert!(!expanded.contains("{{uuid}}"));
        let uuid_part = expanded.strip_prefix("session-").expect("prefix must be preserved");
        assert_eq!(uuid_part.len(), 36);
        assert!(uuid::Uuid::parse_str(uuid_part).is_ok(), "not a UUID: {}", uuid_part);
    }

    #[test]
    fn test_expand_tokens_now() {
        let expanded = expand_tokens("time: {{now}}");
        assert!(!expanded.contains("{{now}}"));
        let stamp = expanded.strip_prefix("time: ").expect("prefix must be preserved");
        assert!(chrono::DateTime::parse_from_rfc3339(stamp).is_ok(), "not RFC 3339: {}", stamp);
    }

    #[test]
    fn test_expand_tokens_now_plus_1m() {
        let expanded = expand_tokens("expires: {{now+1m}}");
        assert!(!expanded.contains("{{now+1m}}"));
        let stamp = expanded.strip_prefix("expires: ").expect("prefix must be preserved");
        assert!(chrono::DateTime::parse_from_rfc3339(stamp).is_ok(), "not RFC 3339: {}", stamp);
    }

    #[test]
    fn test_expand_tokens_now_plus_1h() {
        let expanded = expand_tokens("expires: {{now+1h}}");
        assert!(!expanded.contains("{{now+1h}}"));
        let stamp = expanded.strip_prefix("expires: ").expect("prefix must be preserved");
        assert!(chrono::DateTime::parse_from_rfc3339(stamp).is_ok(), "not RFC 3339: {}", stamp);
    }

    #[test]
    fn test_expand_tokens_randint() {
        let expanded = expand_tokens("value: {{randInt 1 100}}");
        assert!(!expanded.contains("{{randInt"), "Token should be expanded");
        let value = expanded.strip_prefix("value: ").expect("prefix must be preserved");
        assert!(value.parse::<i32>().is_ok(), "expected an integer, got {}", value);
    }

    #[test]
    fn test_expand_tokens_randint_multiple() {
        let expanded = expand_tokens("a: {{randInt 1 10}}, b: {{randInt 20 30}}");
        assert!(!expanded.contains("{{randInt"));
        assert!(expanded.contains("a: "));
        assert!(expanded.contains("b: "));
    }

    #[test]
    fn test_expand_tokens_randint_malformed_bounds_expand_to_zero() {
        assert_eq!(expand_tokens("v: {{randInt a b}}"), "v: 0");
    }

    #[test]
    fn test_expand_tokens_randint_unterminated_left_as_is() {
        assert_eq!(expand_tokens("v: {{randInt 1 2"), "v: {{randInt 1 2");
    }

    #[test]
    fn test_expand_tokens_mixed() {
        let expanded = expand_tokens("id: {{uuid}}, time: {{now}}, rand: {{randInt 1 10}}");
        assert!(!expanded.contains("{{uuid}}"));
        assert!(!expanded.contains("{{now}}"));
        assert!(!expanded.contains("{{randInt"));
    }

    #[test]
    fn test_expand_tokens_no_tokens() {
        let text = "plain text without tokens";
        assert_eq!(expand_tokens(text), text);
    }

    // ------------------------------------------------------------------
    // Latency profiles
    // ------------------------------------------------------------------

    #[test]
    fn test_latency_profile_default() {
        let profile = LatencyProfile::default();
        let injector = LatencyInjector::new(profile, Default::default());
        let _router = router_with_latency(injector);
    }

    #[test]
    fn test_latency_profile_with_normal_distribution() {
        let profile = LatencyProfile::with_normal_distribution(100, 25.0)
            .with_min_ms(50)
            .with_max_ms(200);
        let injector = LatencyInjector::new(profile, Default::default());
        let _router = router_with_latency(injector);
    }

    // ------------------------------------------------------------------
    // Proxy configuration
    // ------------------------------------------------------------------

    #[test]
    fn test_ws_proxy_config_default() {
        let config = mockforge_core::WsProxyConfig::default();
        let _url = &config.upstream_url;
    }

    #[test]
    fn test_ws_proxy_config_custom() {
        let config = mockforge_core::WsProxyConfig {
            upstream_url: "wss://api.example.com/ws".to_string(),
            ..Default::default()
        };
        assert_eq!(config.upstream_url, "wss://api.example.com/ws");
    }

    // ------------------------------------------------------------------
    // Re-exports
    // ------------------------------------------------------------------

    #[test]
    fn test_reexports_available() {
        let _ = create_ws_connection_span("conn-123");
        let _registry = HandlerRegistry::new();
        let _pattern = MessagePattern::any();
    }
}