mockforge_ws/
lib.rs

1//! # MockForge WebSocket
2//!
3//! WebSocket mocking library for MockForge with replay, proxy, and AI-powered event generation.
4//!
5//! This crate provides comprehensive WebSocket mocking capabilities, including:
6//!
7//! - **Replay Mode**: Script and replay WebSocket message sequences
8//! - **Interactive Mode**: Dynamic responses based on client messages
9//! - **AI Event Streams**: Generate narrative-driven event sequences
10//! - **Proxy Mode**: Forward messages to real WebSocket backends
11//! - **JSONPath Matching**: Sophisticated message matching with JSONPath queries
12//!
13//! ## Overview
14//!
15//! MockForge WebSocket supports multiple operational modes:
16//!
17//! ### 1. Replay Mode
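//! Play back pre-recorded WebSocket interactions from JSONL files. The default `/ws` handler
//! switches to replay when the `MOCKFORGE_WS_REPLAY_FILE` environment variable names a replay
//! file; template tokens in replayed messages are expanded only when
//! `MOCKFORGE_RESPONSE_TEMPLATE_EXPAND` is set to `1` or `true`.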
19//!
20//! ### 2. Proxy Mode
21//! Forward WebSocket messages to upstream servers with optional message transformation.
22//!
23//! ### 3. AI Event Generation
24//! Generate realistic event streams using LLMs based on narrative descriptions.
25//!
26//! ## Quick Start
27//!
28//! ### Basic WebSocket Server
29//!
30//! ```rust,no_run
31//! use mockforge_ws::router;
32//! use std::net::SocketAddr;
33//!
34//! #[tokio::main]
35//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
36//!     // Create WebSocket router
37//!     let app = router();
38//!
39//!     // Start server
40//!     let addr: SocketAddr = "0.0.0.0:3001".parse()?;
41//!     let listener = tokio::net::TcpListener::bind(addr).await?;
42//!     axum::serve(listener, app).await?;
43//!
44//!     Ok(())
45//! }
46//! ```
47//!
48//! ### With Latency Simulation
49//!
50//! ```rust,no_run
51//! use mockforge_ws::router_with_latency;
52//! use mockforge_core::latency::{FaultConfig, LatencyInjector};
53//! use mockforge_core::LatencyProfile;
54//!
55//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
56//! let latency = LatencyProfile::with_normal_distribution(250, 75.0)
57//!     .with_min_ms(100)
58//!     .with_max_ms(500);
59//! let injector = LatencyInjector::new(latency, FaultConfig::default());
60//! let app = router_with_latency(injector);
61//! # Ok(())
62//! # }
63//! ```
64//!
65//! ### With Proxy Support
66//!
67//! ```rust,no_run
68//! use mockforge_ws::router_with_proxy;
69//! use mockforge_core::{WsProxyHandler, WsProxyConfig};
70//!
71//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
72//! let proxy_config = WsProxyConfig {
73//!     upstream_url: "wss://api.example.com/ws".to_string(),
74//!     ..Default::default()
75//! };
76//! let proxy = WsProxyHandler::new(proxy_config);
77//! let app = router_with_proxy(proxy);
78//! # Ok(())
79//! # }
80//! ```
81//!
82//! ### AI Event Generation
83//!
84//! Generate realistic event streams from narrative descriptions:
85//!
86//! ```rust,no_run
87//! use mockforge_ws::{AiEventGenerator, WebSocketAiConfig};
88//! use mockforge_data::replay_augmentation::{scenarios, ReplayMode};
89//!
90//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
91//! let ai_config = WebSocketAiConfig {
92//!     enabled: true,
93//!     replay: Some(scenarios::stock_market_scenario()),
94//!     max_events: Some(30),
95//!     event_rate: Some(1.5),
96//! };
97//!
98//! let generator = AiEventGenerator::new(ai_config.replay.clone().unwrap())?;
99//! let _events = generator; // use the generator with `stream_events` in your handler
100//! # Ok(())
101//! # }
102//! ```
103//!
104//! ## Replay File Format
105//!
106//! WebSocket replay files use JSON Lines (JSONL) format:
107//!
108//! ```json
109//! {"ts":0,"dir":"out","text":"HELLO {{uuid}}","waitFor":"^CLIENT_READY$"}
110//! {"ts":10,"dir":"out","text":"{\"type\":\"welcome\",\"sessionId\":\"{{uuid}}\"}"}
111//! {"ts":20,"dir":"out","text":"{\"data\":{{randInt 1 100}}}","waitFor":"^ACK$"}
112//! ```
113//!
114//! Fields:
115//! - `ts`: Timestamp in milliseconds
116//! - `dir`: Direction ("in" = received, "out" = sent)
117//! - `text`: Message content (supports template expansion)
118//! - `waitFor`: Optional regex/JSONPath pattern to wait for
119//!
120//! ## JSONPath Message Matching
121//!
122//! Match messages using JSONPath queries:
123//!
124//! ```json
125//! {"waitFor": "$.type", "text": "Type received"}
126//! {"waitFor": "$.user.id", "text": "User authenticated"}
127//! {"waitFor": "$.order.status", "text": "Order updated"}
128//! ```
129//!
130//! ## Key Modules
131//!
//! - [`ai_event_generator`]: AI-powered event stream generation
//! - [`handlers`]: Programmable message handlers (`WsHandler`), message routing, and rooms
//! - [`ws_tracing`]: Distributed tracing integration
134//!
135//! ## Examples
136//!
137//! See the [examples directory](https://github.com/SaaSy-Solutions/mockforge/tree/main/examples)
138//! for complete working examples.
139//!
140//! ## Related Crates
141//!
142//! - [`mockforge-core`](https://docs.rs/mockforge-core): Core mocking functionality
143//! - [`mockforge-data`](https://docs.rs/mockforge-data): Synthetic data generation
144//!
145//! ## Documentation
146//!
147//! - [MockForge Book](https://docs.mockforge.dev/)
148//! - [WebSocket Mocking Guide](https://docs.mockforge.dev/user-guide/websocket-mocking.html)
149//! - [API Reference](https://docs.rs/mockforge-ws)
150
151pub mod ai_event_generator;
152pub mod handlers;
153pub mod ws_tracing;
154
155use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
156use axum::extract::{Path, State};
157use axum::{response::IntoResponse, routing::get, Router};
158use futures::sink::SinkExt;
159use futures::stream::StreamExt;
160use mockforge_core::{latency::LatencyInjector, LatencyProfile, WsProxyHandler};
161#[cfg(feature = "data-faker")]
162use mockforge_data::provider::register_core_faker_provider;
163use mockforge_observability::get_global_registry;
164use serde_json::Value;
165use tokio::fs;
166use tokio::time::{sleep, Duration};
167use tracing::*;
168
169// Re-export AI event generator utilities
170pub use ai_event_generator::{AiEventGenerator, WebSocketAiConfig};
171
172// Re-export tracing utilities
173pub use ws_tracing::{
174    create_ws_connection_span, create_ws_message_span, record_ws_connection_success,
175    record_ws_error, record_ws_message_success,
176};
177
178// Re-export handler utilities
179pub use handlers::{
180    HandlerError, HandlerRegistry, HandlerResult, MessagePattern, MessageRouter, PassthroughConfig,
181    PassthroughHandler, RoomManager, WsContext, WsHandler, WsMessage,
182};
183
184/// Build the WebSocket router (exposed for tests and embedding)
185pub fn router() -> Router {
186    #[cfg(feature = "data-faker")]
187    register_core_faker_provider();
188
189    Router::new().route("/ws", get(ws_handler_no_state))
190}
191
192/// Build the WebSocket router with latency injector state
193pub fn router_with_latency(latency_injector: LatencyInjector) -> Router {
194    #[cfg(feature = "data-faker")]
195    register_core_faker_provider();
196
197    Router::new()
198        .route("/ws", get(ws_handler_with_state))
199        .with_state(latency_injector)
200}
201
202/// Build the WebSocket router with proxy handler
203pub fn router_with_proxy(proxy_handler: WsProxyHandler) -> Router {
204    #[cfg(feature = "data-faker")]
205    register_core_faker_provider();
206
207    Router::new()
208        .route("/ws", get(ws_handler_with_proxy))
209        .route("/ws/{*path}", get(ws_handler_with_proxy_path))
210        .with_state(proxy_handler)
211}
212
213/// Build the WebSocket router with handler registry
214pub fn router_with_handlers(registry: std::sync::Arc<HandlerRegistry>) -> Router {
215    #[cfg(feature = "data-faker")]
216    register_core_faker_provider();
217
218    Router::new()
219        .route("/ws", get(ws_handler_with_registry))
220        .route("/ws/{*path}", get(ws_handler_with_registry_path))
221        .with_state(registry)
222}
223
224/// Start WebSocket server with latency simulation
225pub async fn start_with_latency(
226    port: u16,
227    latency: Option<LatencyProfile>,
228) -> Result<(), Box<dyn std::error::Error>> {
229    let latency_injector = latency.map(|profile| LatencyInjector::new(profile, Default::default()));
230    let router = if let Some(injector) = latency_injector {
231        router_with_latency(injector)
232    } else {
233        router()
234    };
235
236    let addr: std::net::SocketAddr = format!("127.0.0.1:{}", port).parse()?;
237    info!("WebSocket server listening on {}", addr);
238
239    let listener = tokio::net::TcpListener::bind(addr).await.map_err(|e| {
240        format!(
241            "Failed to bind WebSocket server to port {}: {}\n\
242             Hint: The port may already be in use. Try using a different port with --ws-port or check if another process is using this port with: lsof -i :{} or netstat -tulpn | grep {}",
243            port, e, port, port
244        )
245    })?;
246
247    axum::serve(listener, router).await?;
248    Ok(())
249}
250
251// WebSocket handlers
252async fn ws_handler_no_state(ws: WebSocketUpgrade) -> impl IntoResponse {
253    ws.on_upgrade(handle_socket)
254}
255
256async fn ws_handler_with_state(
257    ws: WebSocketUpgrade,
258    axum::extract::State(_latency): axum::extract::State<LatencyInjector>,
259) -> impl IntoResponse {
260    ws.on_upgrade(handle_socket)
261}
262
263async fn ws_handler_with_proxy(
264    ws: WebSocketUpgrade,
265    State(proxy): State<WsProxyHandler>,
266) -> impl IntoResponse {
267    ws.on_upgrade(move |socket| handle_socket_with_proxy(socket, proxy, "/ws".to_string()))
268}
269
270async fn ws_handler_with_proxy_path(
271    Path(path): Path<String>,
272    ws: WebSocketUpgrade,
273    State(proxy): State<WsProxyHandler>,
274) -> impl IntoResponse {
275    let full_path = format!("/ws/{}", path);
276    ws.on_upgrade(move |socket| handle_socket_with_proxy(socket, proxy, full_path))
277}
278
279async fn ws_handler_with_registry(
280    ws: WebSocketUpgrade,
281    State(registry): State<std::sync::Arc<HandlerRegistry>>,
282) -> impl IntoResponse {
283    ws.on_upgrade(move |socket| handle_socket_with_handlers(socket, registry, "/ws".to_string()))
284}
285
286async fn ws_handler_with_registry_path(
287    Path(path): Path<String>,
288    ws: WebSocketUpgrade,
289    State(registry): State<std::sync::Arc<HandlerRegistry>>,
290) -> impl IntoResponse {
291    let full_path = format!("/ws/{}", path);
292    ws.on_upgrade(move |socket| handle_socket_with_handlers(socket, registry, full_path))
293}
294
295async fn handle_socket(mut socket: WebSocket) {
296    use std::time::Instant;
297
298    // Track WebSocket connection
299    let registry = get_global_registry();
300    let connection_start = Instant::now();
301    registry.record_ws_connection_established();
302    debug!("WebSocket connection established, tracking metrics");
303
304    // Track connection status (for metrics reporting)
305    let mut status = "normal";
306
307    // Check if replay mode is enabled
308    if let Ok(replay_file) = std::env::var("MOCKFORGE_WS_REPLAY_FILE") {
309        info!("WebSocket replay mode enabled with file: {}", replay_file);
310        handle_socket_with_replay(socket, &replay_file).await;
311    } else {
312        // Normal echo mode
313        while let Some(msg) = socket.recv().await {
314            match msg {
315                Ok(Message::Text(text)) => {
316                    registry.record_ws_message_received();
317
318                    // Echo the message back with "echo: " prefix
319                    let response = format!("echo: {}", text);
320                    if socket.send(Message::Text(response.into())).await.is_err() {
321                        status = "send_error";
322                        break;
323                    }
324                    registry.record_ws_message_sent();
325                }
326                Ok(Message::Close(_)) => {
327                    status = "client_close";
328                    break;
329                }
330                Err(e) => {
331                    error!("WebSocket error: {}", e);
332                    registry.record_ws_error();
333                    status = "error";
334                    break;
335                }
336                _ => {}
337            }
338        }
339    }
340
341    // Connection closed - record duration
342    let duration = connection_start.elapsed().as_secs_f64();
343    registry.record_ws_connection_closed(duration, status);
344    debug!("WebSocket connection closed (status: {}, duration: {:.2}s)", status, duration);
345}
346
347async fn handle_socket_with_replay(mut socket: WebSocket, replay_file: &str) {
348    let _registry = get_global_registry(); // Available for future message tracking
349
350    // Read the replay file
351    let content = match fs::read_to_string(replay_file).await {
352        Ok(content) => content,
353        Err(e) => {
354            error!("Failed to read replay file {}: {}", replay_file, e);
355            return;
356        }
357    };
358
359    // Parse JSONL file
360    let mut replay_entries = Vec::new();
361    for line in content.lines() {
362        if let Ok(entry) = serde_json::from_str::<Value>(line) {
363            replay_entries.push(entry);
364        }
365    }
366
367    info!("Loaded {} replay entries", replay_entries.len());
368
369    // Process replay entries
370    for entry in replay_entries {
371        // Check if we need to wait for a specific message
372        if let Some(wait_for) = entry.get("waitFor") {
373            if let Some(wait_pattern) = wait_for.as_str() {
374                info!("Waiting for pattern: {}", wait_pattern);
375                // Wait for matching message from client
376                let mut found = false;
377                while let Some(msg) = socket.recv().await {
378                    if let Ok(Message::Text(text)) = msg {
379                        if text.contains(wait_pattern) || wait_pattern == "^CLIENT_READY$" {
380                            found = true;
381                            break;
382                        }
383                    }
384                }
385                if !found {
386                    break;
387                }
388            }
389        }
390
391        // Get the message text
392        if let Some(text) = entry.get("text").and_then(|v| v.as_str()) {
393            // Expand tokens if enabled
394            let expanded_text = if std::env::var("MOCKFORGE_RESPONSE_TEMPLATE_EXPAND")
395                .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
396                .unwrap_or(false)
397            {
398                expand_tokens(text)
399            } else {
400                text.to_string()
401            };
402
403            info!("Sending replay message: {}", expanded_text);
404            if socket.send(Message::Text(expanded_text.into())).await.is_err() {
405                break;
406            }
407        }
408
409        // Wait for the specified time
410        if let Some(ts) = entry.get("ts").and_then(|v| v.as_u64()) {
411            sleep(Duration::from_millis(ts * 10)).await; // Convert to milliseconds
412        }
413    }
414}
415
416fn expand_tokens(text: &str) -> String {
417    let mut result = text.to_string();
418
419    // Expand {{uuid}}
420    result = result.replace("{{uuid}}", &uuid::Uuid::new_v4().to_string());
421
422    // Expand {{now}}
423    result = result.replace("{{now}}", &chrono::Utc::now().to_rfc3339());
424
425    // Expand {{now+1m}} (add 1 minute)
426    if result.contains("{{now+1m}}") {
427        let now_plus_1m = chrono::Utc::now() + chrono::Duration::minutes(1);
428        result = result.replace("{{now+1m}}", &now_plus_1m.to_rfc3339());
429    }
430
431    // Expand {{now+1h}} (add 1 hour)
432    if result.contains("{{now+1h}}") {
433        let now_plus_1h = chrono::Utc::now() + chrono::Duration::hours(1);
434        result = result.replace("{{now+1h}}", &now_plus_1h.to_rfc3339());
435    }
436
437    // Expand {{randInt min max}}
438    while result.contains("{{randInt") {
439        if let Some(start) = result.find("{{randInt") {
440            if let Some(end) = result[start..].find("}}") {
441                let full_match = &result[start..start + end + 2];
442                let content = &result[start + 9..start + end]; // Skip "{{randInt"
443
444                if let Some(space_pos) = content.find(' ') {
445                    let min_str = &content[..space_pos];
446                    let max_str = &content[space_pos + 1..];
447
448                    if let (Ok(min), Ok(max)) = (min_str.parse::<i32>(), max_str.parse::<i32>()) {
449                        let random_value = fastrand::i32(min..=max);
450                        result = result.replace(full_match, &random_value.to_string());
451                    } else {
452                        result = result.replace(full_match, "0");
453                    }
454                } else {
455                    result = result.replace(full_match, "0");
456                }
457            } else {
458                break;
459            }
460        } else {
461            break;
462        }
463    }
464
465    result
466}
467
468async fn handle_socket_with_proxy(socket: WebSocket, proxy: WsProxyHandler, path: String) {
469    use std::time::Instant;
470
471    let registry = get_global_registry();
472    let connection_start = Instant::now();
473    registry.record_ws_connection_established();
474
475    let mut status = "normal";
476
477    // Check if this connection should be proxied
478    if proxy.config.should_proxy(&path) {
479        info!("Proxying WebSocket connection for path: {}", path);
480        if let Err(e) = proxy.proxy_connection(&path, socket).await {
481            error!("Failed to proxy WebSocket connection: {}", e);
482            registry.record_ws_error();
483            status = "proxy_error";
484        }
485    } else {
486        info!("Handling WebSocket connection locally for path: {}", path);
487        // Handle locally by echoing messages
488        // Note: handle_socket already tracks its own connection metrics,
489        // so we need to avoid double-counting
490        registry.record_ws_connection_closed(0.0, ""); // Decrement the one we just added
491        handle_socket(socket).await;
492        return; // Early return to avoid double-tracking
493    }
494
495    let duration = connection_start.elapsed().as_secs_f64();
496    registry.record_ws_connection_closed(duration, status);
497    debug!(
498        "Proxied WebSocket connection closed (status: {}, duration: {:.2}s)",
499        status, duration
500    );
501}
502
503async fn handle_socket_with_handlers(
504    socket: WebSocket,
505    registry: std::sync::Arc<HandlerRegistry>,
506    path: String,
507) {
508    use std::time::Instant;
509
510    let metrics_registry = get_global_registry();
511    let connection_start = Instant::now();
512    metrics_registry.record_ws_connection_established();
513
514    let mut status = "normal";
515
516    // Generate unique connection ID
517    let connection_id = uuid::Uuid::new_v4().to_string();
518
519    // Get handlers for this path
520    let handlers = registry.get_handlers(&path);
521    if handlers.is_empty() {
522        info!("No handlers found for path: {}, falling back to echo mode", path);
523        metrics_registry.record_ws_connection_closed(0.0, "");
524        handle_socket(socket).await;
525        return;
526    }
527
528    info!(
529        "Handling WebSocket connection with {} handler(s) for path: {}",
530        handlers.len(),
531        path
532    );
533
534    // Create room manager
535    let room_manager = RoomManager::new();
536
537    // Split socket for concurrent send/receive
538    let (mut socket_sender, mut socket_receiver) = socket.split();
539
540    // Create message channel for handlers to send messages
541    let (message_tx, mut message_rx) = tokio::sync::mpsc::unbounded_channel::<Message>();
542
543    // Create context
544    let mut ctx =
545        WsContext::new(connection_id.clone(), path.clone(), room_manager.clone(), message_tx);
546
547    // Call on_connect for all handlers
548    for handler in &handlers {
549        if let Err(e) = handler.on_connect(&mut ctx).await {
550            error!("Handler on_connect error: {}", e);
551            status = "handler_error";
552        }
553    }
554
555    // Spawn task to send messages from handlers to the socket
556    let send_task = tokio::spawn(async move {
557        while let Some(msg) = message_rx.recv().await {
558            if socket_sender.send(msg).await.is_err() {
559                break;
560            }
561        }
562    });
563
564    // Handle incoming messages
565    while let Some(msg) = socket_receiver.next().await {
566        match msg {
567            Ok(axum_msg) => {
568                metrics_registry.record_ws_message_received();
569
570                let ws_msg: WsMessage = axum_msg.into();
571
572                // Check for close message
573                if matches!(ws_msg, WsMessage::Close) {
574                    status = "client_close";
575                    break;
576                }
577
578                // Pass message through all handlers
579                for handler in &handlers {
580                    if let Err(e) = handler.on_message(&mut ctx, ws_msg.clone()).await {
581                        error!("Handler on_message error: {}", e);
582                        status = "handler_error";
583                    }
584                }
585
586                metrics_registry.record_ws_message_sent();
587            }
588            Err(e) => {
589                error!("WebSocket error: {}", e);
590                metrics_registry.record_ws_error();
591                status = "error";
592                break;
593            }
594        }
595    }
596
597    // Call on_disconnect for all handlers
598    for handler in &handlers {
599        if let Err(e) = handler.on_disconnect(&mut ctx).await {
600            error!("Handler on_disconnect error: {}", e);
601        }
602    }
603
604    // Clean up room memberships
605    let _ = room_manager.leave_all(&connection_id).await;
606
607    // Abort send task
608    send_task.abort();
609
610    let duration = connection_start.elapsed().as_secs_f64();
611    metrics_registry.record_ws_connection_closed(duration, status);
612    debug!(
613        "Handler-based WebSocket connection closed (status: {}, duration: {:.2}s)",
614        status, duration
615    );
616}
617
618#[cfg(test)]
619mod tests {
620    use super::*;
621
622    #[test]
623    fn test_router_creation() {
624        let _router = router();
625        // Router should be created successfully
626    }
627
628    #[test]
629    fn test_router_with_latency_creation() {
630        let latency_profile = LatencyProfile::default();
631        let latency_injector = LatencyInjector::new(latency_profile, Default::default());
632        let _router = router_with_latency(latency_injector);
633        // Router should be created successfully
634    }
635
636    #[test]
637    fn test_router_with_proxy_creation() {
638        let config = mockforge_core::WsProxyConfig {
639            upstream_url: "ws://localhost:8080".to_string(),
640            ..Default::default()
641        };
642        let proxy_handler = WsProxyHandler::new(config);
643        let _router = router_with_proxy(proxy_handler);
644        // Router should be created successfully
645    }
646
647    #[tokio::test]
648    async fn test_start_with_latency_config_none() {
649        // Test that we can create the router without latency
650        let result = std::panic::catch_unwind(|| {
651            let _router = router();
652        });
653        assert!(result.is_ok());
654    }
655
656    #[tokio::test]
657    async fn test_start_with_latency_config_some() {
658        // Test that we can create the router with latency
659        let latency_profile = LatencyProfile::default();
660        let latency_injector = LatencyInjector::new(latency_profile, Default::default());
661
662        let result = std::panic::catch_unwind(|| {
663            let _router = router_with_latency(latency_injector);
664        });
665        assert!(result.is_ok());
666    }
667}