1pub mod ai_event_generator;
152pub mod handlers;
153pub mod ws_tracing;
154
155use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
156use axum::extract::{Path, State};
157use axum::{response::IntoResponse, routing::get, Router};
158use futures::sink::SinkExt;
159use futures::stream::StreamExt;
160use mockforge_core::{latency::LatencyInjector, LatencyProfile, WsProxyHandler};
161#[cfg(feature = "data-faker")]
162use mockforge_data::provider::register_core_faker_provider;
163use mockforge_observability::get_global_registry;
164use serde_json::Value;
165use tokio::fs;
166use tokio::time::{sleep, Duration};
167use tracing::*;
168
169pub use ai_event_generator::{AiEventGenerator, WebSocketAiConfig};
171
172pub use ws_tracing::{
174 create_ws_connection_span, create_ws_message_span, record_ws_connection_success,
175 record_ws_error, record_ws_message_success,
176};
177
178pub use handlers::{
180 HandlerError, HandlerRegistry, HandlerResult, MessagePattern, MessageRouter, PassthroughConfig,
181 PassthroughHandler, RoomManager, WsContext, WsHandler, WsMessage,
182};
183
184pub fn router() -> Router {
186 #[cfg(feature = "data-faker")]
187 register_core_faker_provider();
188
189 Router::new().route("/ws", get(ws_handler_no_state))
190}
191
192pub fn router_with_latency(latency_injector: LatencyInjector) -> Router {
194 #[cfg(feature = "data-faker")]
195 register_core_faker_provider();
196
197 Router::new()
198 .route("/ws", get(ws_handler_with_state))
199 .with_state(latency_injector)
200}
201
202pub fn router_with_proxy(proxy_handler: WsProxyHandler) -> Router {
204 #[cfg(feature = "data-faker")]
205 register_core_faker_provider();
206
207 Router::new()
208 .route("/ws", get(ws_handler_with_proxy))
209 .route("/ws/{*path}", get(ws_handler_with_proxy_path))
210 .with_state(proxy_handler)
211}
212
213pub fn router_with_handlers(registry: std::sync::Arc<HandlerRegistry>) -> Router {
215 #[cfg(feature = "data-faker")]
216 register_core_faker_provider();
217
218 Router::new()
219 .route("/ws", get(ws_handler_with_registry))
220 .route("/ws/{*path}", get(ws_handler_with_registry_path))
221 .with_state(registry)
222}
223
224pub async fn start_with_latency(
226 port: u16,
227 latency: Option<LatencyProfile>,
228) -> Result<(), Box<dyn std::error::Error>> {
229 start_with_latency_and_host(port, "0.0.0.0", latency).await
230}
231
232pub async fn start_with_latency_and_host(
234 port: u16,
235 host: &str,
236 latency: Option<LatencyProfile>,
237) -> Result<(), Box<dyn std::error::Error>> {
238 let latency_injector = latency.map(|profile| LatencyInjector::new(profile, Default::default()));
239 let router = if let Some(injector) = latency_injector {
240 router_with_latency(injector)
241 } else {
242 router()
243 };
244
245 let addr: std::net::SocketAddr = format!("{}:{}", host, port).parse()?;
246 info!("WebSocket server listening on {}", addr);
247
248 let listener = tokio::net::TcpListener::bind(addr).await.map_err(|e| {
249 format!(
250 "Failed to bind WebSocket server to port {}: {}\n\
251 Hint: The port may already be in use. Try using a different port with --ws-port or check if another process is using this port with: lsof -i :{} or netstat -tulpn | grep {}",
252 port, e, port, port
253 )
254 })?;
255
256 axum::serve(listener, router).await?;
257 Ok(())
258}
259
260async fn ws_handler_no_state(ws: WebSocketUpgrade) -> impl IntoResponse {
262 ws.on_upgrade(handle_socket)
263}
264
265async fn ws_handler_with_state(
266 ws: WebSocketUpgrade,
267 axum::extract::State(_latency): axum::extract::State<LatencyInjector>,
268) -> impl IntoResponse {
269 ws.on_upgrade(handle_socket)
270}
271
272async fn ws_handler_with_proxy(
273 ws: WebSocketUpgrade,
274 State(proxy): State<WsProxyHandler>,
275) -> impl IntoResponse {
276 ws.on_upgrade(move |socket| handle_socket_with_proxy(socket, proxy, "/ws".to_string()))
277}
278
279async fn ws_handler_with_proxy_path(
280 Path(path): Path<String>,
281 ws: WebSocketUpgrade,
282 State(proxy): State<WsProxyHandler>,
283) -> impl IntoResponse {
284 let full_path = format!("/ws/{}", path);
285 ws.on_upgrade(move |socket| handle_socket_with_proxy(socket, proxy, full_path))
286}
287
288async fn ws_handler_with_registry(
289 ws: WebSocketUpgrade,
290 State(registry): State<std::sync::Arc<HandlerRegistry>>,
291) -> impl IntoResponse {
292 ws.on_upgrade(move |socket| handle_socket_with_handlers(socket, registry, "/ws".to_string()))
293}
294
295async fn ws_handler_with_registry_path(
296 Path(path): Path<String>,
297 ws: WebSocketUpgrade,
298 State(registry): State<std::sync::Arc<HandlerRegistry>>,
299) -> impl IntoResponse {
300 let full_path = format!("/ws/{}", path);
301 ws.on_upgrade(move |socket| handle_socket_with_handlers(socket, registry, full_path))
302}
303
304async fn handle_socket(mut socket: WebSocket) {
305 use std::time::Instant;
306
307 let registry = get_global_registry();
309 let connection_start = Instant::now();
310 registry.record_ws_connection_established();
311 debug!("WebSocket connection established, tracking metrics");
312
313 let mut status = "normal";
315
316 if let Ok(replay_file) = std::env::var("MOCKFORGE_WS_REPLAY_FILE") {
318 info!("WebSocket replay mode enabled with file: {}", replay_file);
319 handle_socket_with_replay(socket, &replay_file).await;
320 } else {
321 while let Some(msg) = socket.recv().await {
323 match msg {
324 Ok(Message::Text(text)) => {
325 registry.record_ws_message_received();
326
327 let response = format!("echo: {}", text);
329 if socket.send(Message::Text(response.into())).await.is_err() {
330 status = "send_error";
331 break;
332 }
333 registry.record_ws_message_sent();
334 }
335 Ok(Message::Close(_)) => {
336 status = "client_close";
337 break;
338 }
339 Err(e) => {
340 error!("WebSocket error: {}", e);
341 registry.record_ws_error();
342 status = "error";
343 break;
344 }
345 _ => {}
346 }
347 }
348 }
349
350 let duration = connection_start.elapsed().as_secs_f64();
352 registry.record_ws_connection_closed(duration, status);
353 debug!("WebSocket connection closed (status: {}, duration: {:.2}s)", status, duration);
354}
355
356async fn handle_socket_with_replay(mut socket: WebSocket, replay_file: &str) {
357 let _registry = get_global_registry(); let content = match fs::read_to_string(replay_file).await {
361 Ok(content) => content,
362 Err(e) => {
363 error!("Failed to read replay file {}: {}", replay_file, e);
364 return;
365 }
366 };
367
368 let mut replay_entries = Vec::new();
370 for line in content.lines() {
371 if let Ok(entry) = serde_json::from_str::<Value>(line) {
372 replay_entries.push(entry);
373 }
374 }
375
376 info!("Loaded {} replay entries", replay_entries.len());
377
378 for entry in replay_entries {
380 if let Some(wait_for) = entry.get("waitFor") {
382 if let Some(wait_pattern) = wait_for.as_str() {
383 info!("Waiting for pattern: {}", wait_pattern);
384 let mut found = false;
386 while let Some(msg) = socket.recv().await {
387 if let Ok(Message::Text(text)) = msg {
388 if text.contains(wait_pattern) || wait_pattern == "^CLIENT_READY$" {
389 found = true;
390 break;
391 }
392 }
393 }
394 if !found {
395 break;
396 }
397 }
398 }
399
400 if let Some(text) = entry.get("text").and_then(|v| v.as_str()) {
402 let expanded_text = if std::env::var("MOCKFORGE_RESPONSE_TEMPLATE_EXPAND")
404 .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
405 .unwrap_or(false)
406 {
407 expand_tokens(text)
408 } else {
409 text.to_string()
410 };
411
412 info!("Sending replay message: {}", expanded_text);
413 if socket.send(Message::Text(expanded_text.into())).await.is_err() {
414 break;
415 }
416 }
417
418 if let Some(ts) = entry.get("ts").and_then(|v| v.as_u64()) {
420 sleep(Duration::from_millis(ts * 10)).await; }
422 }
423}
424
425fn expand_tokens(text: &str) -> String {
426 let mut result = text.to_string();
427
428 result = result.replace("{{uuid}}", &uuid::Uuid::new_v4().to_string());
430
431 result = result.replace("{{now}}", &chrono::Utc::now().to_rfc3339());
433
434 if result.contains("{{now+1m}}") {
436 let now_plus_1m = chrono::Utc::now() + chrono::Duration::minutes(1);
437 result = result.replace("{{now+1m}}", &now_plus_1m.to_rfc3339());
438 }
439
440 if result.contains("{{now+1h}}") {
442 let now_plus_1h = chrono::Utc::now() + chrono::Duration::hours(1);
443 result = result.replace("{{now+1h}}", &now_plus_1h.to_rfc3339());
444 }
445
446 while result.contains("{{randInt") {
448 if let Some(start) = result.find("{{randInt") {
449 if let Some(end) = result[start..].find("}}") {
450 let full_match = &result[start..start + end + 2];
451 let content = &result[start + 9..start + end]; if let Some(space_pos) = content.find(' ') {
454 let min_str = &content[..space_pos];
455 let max_str = &content[space_pos + 1..];
456
457 if let (Ok(min), Ok(max)) = (min_str.parse::<i32>(), max_str.parse::<i32>()) {
458 let random_value = fastrand::i32(min..=max);
459 result = result.replace(full_match, &random_value.to_string());
460 } else {
461 result = result.replace(full_match, "0");
462 }
463 } else {
464 result = result.replace(full_match, "0");
465 }
466 } else {
467 break;
468 }
469 } else {
470 break;
471 }
472 }
473
474 result
475}
476
477async fn handle_socket_with_proxy(socket: WebSocket, proxy: WsProxyHandler, path: String) {
478 use std::time::Instant;
479
480 let registry = get_global_registry();
481 let connection_start = Instant::now();
482 registry.record_ws_connection_established();
483
484 let mut status = "normal";
485
486 if proxy.config.should_proxy(&path) {
488 info!("Proxying WebSocket connection for path: {}", path);
489 if let Err(e) = proxy.proxy_connection(&path, socket).await {
490 error!("Failed to proxy WebSocket connection: {}", e);
491 registry.record_ws_error();
492 status = "proxy_error";
493 }
494 } else {
495 info!("Handling WebSocket connection locally for path: {}", path);
496 registry.record_ws_connection_closed(0.0, ""); handle_socket(socket).await;
501 return; }
503
504 let duration = connection_start.elapsed().as_secs_f64();
505 registry.record_ws_connection_closed(duration, status);
506 debug!(
507 "Proxied WebSocket connection closed (status: {}, duration: {:.2}s)",
508 status, duration
509 );
510}
511
512async fn handle_socket_with_handlers(
513 socket: WebSocket,
514 registry: std::sync::Arc<HandlerRegistry>,
515 path: String,
516) {
517 use std::time::Instant;
518
519 let metrics_registry = get_global_registry();
520 let connection_start = Instant::now();
521 metrics_registry.record_ws_connection_established();
522
523 let mut status = "normal";
524
525 let connection_id = uuid::Uuid::new_v4().to_string();
527
528 let handlers = registry.get_handlers(&path);
530 if handlers.is_empty() {
531 info!("No handlers found for path: {}, falling back to echo mode", path);
532 metrics_registry.record_ws_connection_closed(0.0, "");
533 handle_socket(socket).await;
534 return;
535 }
536
537 info!(
538 "Handling WebSocket connection with {} handler(s) for path: {}",
539 handlers.len(),
540 path
541 );
542
543 let room_manager = RoomManager::new();
545
546 let (mut socket_sender, mut socket_receiver) = socket.split();
548
549 let (message_tx, mut message_rx) = tokio::sync::mpsc::unbounded_channel::<Message>();
551
552 let mut ctx =
554 WsContext::new(connection_id.clone(), path.clone(), room_manager.clone(), message_tx);
555
556 for handler in &handlers {
558 if let Err(e) = handler.on_connect(&mut ctx).await {
559 error!("Handler on_connect error: {}", e);
560 status = "handler_error";
561 }
562 }
563
564 let send_task = tokio::spawn(async move {
566 while let Some(msg) = message_rx.recv().await {
567 if socket_sender.send(msg).await.is_err() {
568 break;
569 }
570 }
571 });
572
573 while let Some(msg) = socket_receiver.next().await {
575 match msg {
576 Ok(axum_msg) => {
577 metrics_registry.record_ws_message_received();
578
579 let ws_msg: WsMessage = axum_msg.into();
580
581 if matches!(ws_msg, WsMessage::Close) {
583 status = "client_close";
584 break;
585 }
586
587 for handler in &handlers {
589 if let Err(e) = handler.on_message(&mut ctx, ws_msg.clone()).await {
590 error!("Handler on_message error: {}", e);
591 status = "handler_error";
592 }
593 }
594
595 metrics_registry.record_ws_message_sent();
596 }
597 Err(e) => {
598 error!("WebSocket error: {}", e);
599 metrics_registry.record_ws_error();
600 status = "error";
601 break;
602 }
603 }
604 }
605
606 for handler in &handlers {
608 if let Err(e) = handler.on_disconnect(&mut ctx).await {
609 error!("Handler on_disconnect error: {}", e);
610 }
611 }
612
613 let _ = room_manager.leave_all(&connection_id).await;
615
616 send_task.abort();
618
619 let duration = connection_start.elapsed().as_secs_f64();
620 metrics_registry.record_ws_connection_closed(duration, status);
621 debug!(
622 "Handler-based WebSocket connection closed (status: {}, duration: {:.2}s)",
623 status, duration
624 );
625}
626
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_router_creation() {
        let _router = router();
    }

    #[test]
    fn test_router_with_latency_creation() {
        let latency_profile = LatencyProfile::default();
        let latency_injector = LatencyInjector::new(latency_profile, Default::default());
        let _router = router_with_latency(latency_injector);
    }

    #[test]
    fn test_router_with_proxy_creation() {
        let config = mockforge_core::WsProxyConfig {
            upstream_url: "ws://localhost:8080".to_string(),
            ..Default::default()
        };
        let proxy_handler = WsProxyHandler::new(config);
        let _router = router_with_proxy(proxy_handler);
    }

    // The next two tests mirror the router selection done by
    // `start_with_latency_and_host` without binding a socket. A panic while
    // building the router already fails the test, so no `catch_unwind` is needed.
    #[test]
    fn test_start_with_latency_config_none() {
        let _router = router();
    }

    #[test]
    fn test_start_with_latency_config_some() {
        let latency_injector = LatencyInjector::new(LatencyProfile::default(), Default::default());
        let _router = router_with_latency(latency_injector);
    }

    #[test]
    fn test_expand_tokens_leaves_plain_text_untouched() {
        assert_eq!(expand_tokens("plain text"), "plain text");
        assert_eq!(expand_tokens("hello {{name}}"), "hello {{name}}");
    }

    #[test]
    fn test_expand_tokens_uuid_and_time() {
        let id = expand_tokens("{{uuid}}");
        assert!(uuid::Uuid::parse_str(&id).is_ok(), "not a UUID: {id}");

        let now = expand_tokens("{{now}}");
        assert!(chrono::DateTime::parse_from_rfc3339(&now).is_ok(), "not RFC 3339: {now}");

        let later = expand_tokens("{{now+1h}}");
        assert!(chrono::DateTime::parse_from_rfc3339(&later).is_ok(), "not RFC 3339: {later}");
    }

    #[test]
    fn test_expand_tokens_malformed_rand_int() {
        assert_eq!(expand_tokens("n={{randInt}}"), "n=0");
        assert_eq!(expand_tokens("n={{randInt a b}}"), "n=0");
        // Unterminated tokens are left alone.
        assert_eq!(expand_tokens("n={{randInt 1 2"), "n={{randInt 1 2");
    }
}