mockforge_ws/
lib.rs

1//! # MockForge WebSocket
2//!
3//! WebSocket mocking library for MockForge with replay, proxy, and AI-powered event generation.
4//!
5//! This crate provides comprehensive WebSocket mocking capabilities, including:
6//!
7//! - **Replay Mode**: Script and replay WebSocket message sequences
8//! - **Interactive Mode**: Dynamic responses based on client messages
9//! - **AI Event Streams**: Generate narrative-driven event sequences
10//! - **Proxy Mode**: Forward messages to real WebSocket backends
11//! - **JSONPath Matching**: Sophisticated message matching with JSONPath queries
12//!
13//! ## Overview
14//!
15//! MockForge WebSocket supports multiple operational modes:
16//!
17//! ### 1. Replay Mode
18//! Play back pre-recorded WebSocket interactions from JSONL files with template expansion.
19//!
20//! ### 2. Proxy Mode
21//! Forward WebSocket messages to upstream servers with optional message transformation.
22//!
23//! ### 3. AI Event Generation
24//! Generate realistic event streams using LLMs based on narrative descriptions.
25//!
26//! ## Quick Start
27//!
28//! ### Basic WebSocket Server
29//!
30//! ```rust,no_run
31//! use mockforge_ws::router;
32//! use std::net::SocketAddr;
33//!
34//! #[tokio::main]
35//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
36//!     // Create WebSocket router
37//!     let app = router();
38//!
39//!     // Start server
40//!     let addr: SocketAddr = "0.0.0.0:3001".parse()?;
41//!     let listener = tokio::net::TcpListener::bind(addr).await?;
42//!     axum::serve(listener, app).await?;
43//!
44//!     Ok(())
45//! }
46//! ```
47//!
48//! ### With Latency Simulation
49//!
50//! ```rust,no_run
51//! use mockforge_ws::router_with_latency;
52//! use mockforge_core::latency::{FaultConfig, LatencyInjector};
53//! use mockforge_core::LatencyProfile;
54//!
55//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
56//! let latency = LatencyProfile::with_normal_distribution(250, 75.0)
57//!     .with_min_ms(100)
58//!     .with_max_ms(500);
59//! let injector = LatencyInjector::new(latency, FaultConfig::default());
60//! let app = router_with_latency(injector);
61//! # Ok(())
62//! # }
63//! ```
64//!
65//! ### With Proxy Support
66//!
67//! ```rust,no_run
68//! use mockforge_ws::router_with_proxy;
69//! use mockforge_core::{WsProxyHandler, WsProxyConfig};
70//!
71//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
72//! let proxy_config = WsProxyConfig {
73//!     upstream_url: "wss://api.example.com/ws".to_string(),
74//!     ..Default::default()
75//! };
76//! let proxy = WsProxyHandler::new(proxy_config);
77//! let app = router_with_proxy(proxy);
78//! # Ok(())
79//! # }
80//! ```
81//!
82//! ### AI Event Generation
83//!
84//! Generate realistic event streams from narrative descriptions:
85//!
86//! ```rust,no_run
87//! use mockforge_ws::{AiEventGenerator, WebSocketAiConfig};
88//! use mockforge_data::replay_augmentation::{scenarios, ReplayMode};
89//!
90//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
91//! let ai_config = WebSocketAiConfig {
92//!     enabled: true,
93//!     replay: Some(scenarios::stock_market_scenario()),
94//!     max_events: Some(30),
95//!     event_rate: Some(1.5),
96//! };
97//!
98//! let generator = AiEventGenerator::new(ai_config.replay.clone().unwrap())?;
99//! let _events = generator; // use the generator with `stream_events` in your handler
100//! # Ok(())
101//! # }
102//! ```
103//!
104//! ## Replay File Format
105//!
106//! WebSocket replay files use JSON Lines (JSONL) format:
107//!
108//! ```json
109//! {"ts":0,"dir":"out","text":"HELLO {{uuid}}","waitFor":"^CLIENT_READY$"}
110//! {"ts":10,"dir":"out","text":"{\"type\":\"welcome\",\"sessionId\":\"{{uuid}}\"}"}
111//! {"ts":20,"dir":"out","text":"{\"data\":{{randInt 1 100}}}","waitFor":"^ACK$"}
112//! ```
113//!
114//! Fields:
115//! - `ts`: Timestamp in milliseconds
116//! - `dir`: Direction ("in" = received, "out" = sent)
117//! - `text`: Message content (supports template expansion)
118//! - `waitFor`: Optional regex/JSONPath pattern to wait for
119//!
120//! ## JSONPath Message Matching
121//!
122//! Match messages using JSONPath queries:
123//!
124//! ```json
125//! {"waitFor": "$.type", "text": "Type received"}
126//! {"waitFor": "$.user.id", "text": "User authenticated"}
127//! {"waitFor": "$.order.status", "text": "Order updated"}
128//! ```
129//!
130//! ## Key Modules
131//!
//! - [`ai_event_generator`]: AI-powered event stream generation
//! - [`handlers`]: Programmable message handlers, message routing, and rooms
//! - [`ws_tracing`]: Distributed tracing integration
134//!
135//! ## Examples
136//!
137//! See the [examples directory](https://github.com/SaaSy-Solutions/mockforge/tree/main/examples)
138//! for complete working examples.
139//!
140//! ## Related Crates
141//!
142//! - [`mockforge-core`](https://docs.rs/mockforge-core): Core mocking functionality
143//! - [`mockforge-data`](https://docs.rs/mockforge-data): Synthetic data generation
144//!
145//! ## Documentation
146//!
147//! - [MockForge Book](https://docs.mockforge.dev/)
148//! - [WebSocket Mocking Guide](https://docs.mockforge.dev/user-guide/websocket-mocking.html)
149//! - [API Reference](https://docs.rs/mockforge-ws)
150
151pub mod ai_event_generator;
152pub mod handlers;
153pub mod ws_tracing;
154
155use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
156use axum::extract::{Path, State};
157use axum::{response::IntoResponse, routing::get, Router};
158use futures::sink::SinkExt;
159use futures::stream::StreamExt;
160use mockforge_core::{latency::LatencyInjector, LatencyProfile, WsProxyHandler};
161#[cfg(feature = "data-faker")]
162use mockforge_data::provider::register_core_faker_provider;
163use mockforge_observability::get_global_registry;
164use serde_json::Value;
165use tokio::fs;
166use tokio::time::{sleep, Duration};
167use tracing::*;
168
169// Re-export AI event generator utilities
170pub use ai_event_generator::{AiEventGenerator, WebSocketAiConfig};
171
172// Re-export tracing utilities
173pub use ws_tracing::{
174    create_ws_connection_span, create_ws_message_span, record_ws_connection_success,
175    record_ws_error, record_ws_message_success,
176};
177
178// Re-export handler utilities
179pub use handlers::{
180    HandlerError, HandlerRegistry, HandlerResult, MessagePattern, MessageRouter, PassthroughConfig,
181    PassthroughHandler, RoomManager, WsContext, WsHandler, WsMessage,
182};
183
184/// Build the WebSocket router (exposed for tests and embedding)
185pub fn router() -> Router {
186    #[cfg(feature = "data-faker")]
187    register_core_faker_provider();
188
189    Router::new().route("/ws", get(ws_handler_no_state))
190}
191
192/// Build the WebSocket router with latency injector state
193pub fn router_with_latency(latency_injector: LatencyInjector) -> Router {
194    #[cfg(feature = "data-faker")]
195    register_core_faker_provider();
196
197    Router::new()
198        .route("/ws", get(ws_handler_with_state))
199        .with_state(latency_injector)
200}
201
202/// Build the WebSocket router with proxy handler
203pub fn router_with_proxy(proxy_handler: WsProxyHandler) -> Router {
204    #[cfg(feature = "data-faker")]
205    register_core_faker_provider();
206
207    Router::new()
208        .route("/ws", get(ws_handler_with_proxy))
209        .route("/ws/{*path}", get(ws_handler_with_proxy_path))
210        .with_state(proxy_handler)
211}
212
213/// Build the WebSocket router with handler registry
214pub fn router_with_handlers(registry: std::sync::Arc<HandlerRegistry>) -> Router {
215    #[cfg(feature = "data-faker")]
216    register_core_faker_provider();
217
218    Router::new()
219        .route("/ws", get(ws_handler_with_registry))
220        .route("/ws/{*path}", get(ws_handler_with_registry_path))
221        .with_state(registry)
222}
223
224/// Start WebSocket server with latency simulation
225pub async fn start_with_latency(
226    port: u16,
227    latency: Option<LatencyProfile>,
228) -> Result<(), Box<dyn std::error::Error>> {
229    start_with_latency_and_host(port, "0.0.0.0", latency).await
230}
231
232/// Start WebSocket server with latency simulation and custom host
233pub async fn start_with_latency_and_host(
234    port: u16,
235    host: &str,
236    latency: Option<LatencyProfile>,
237) -> Result<(), Box<dyn std::error::Error>> {
238    let latency_injector = latency.map(|profile| LatencyInjector::new(profile, Default::default()));
239    let router = if let Some(injector) = latency_injector {
240        router_with_latency(injector)
241    } else {
242        router()
243    };
244
245    let addr: std::net::SocketAddr = format!("{}:{}", host, port).parse()?;
246    info!("WebSocket server listening on {}", addr);
247
248    let listener = tokio::net::TcpListener::bind(addr).await.map_err(|e| {
249        format!(
250            "Failed to bind WebSocket server to port {}: {}\n\
251             Hint: The port may already be in use. Try using a different port with --ws-port or check if another process is using this port with: lsof -i :{} or netstat -tulpn | grep {}",
252            port, e, port, port
253        )
254    })?;
255
256    axum::serve(listener, router).await?;
257    Ok(())
258}
259
260// WebSocket handlers
261async fn ws_handler_no_state(ws: WebSocketUpgrade) -> impl IntoResponse {
262    ws.on_upgrade(handle_socket)
263}
264
265async fn ws_handler_with_state(
266    ws: WebSocketUpgrade,
267    axum::extract::State(_latency): axum::extract::State<LatencyInjector>,
268) -> impl IntoResponse {
269    ws.on_upgrade(handle_socket)
270}
271
272async fn ws_handler_with_proxy(
273    ws: WebSocketUpgrade,
274    State(proxy): State<WsProxyHandler>,
275) -> impl IntoResponse {
276    ws.on_upgrade(move |socket| handle_socket_with_proxy(socket, proxy, "/ws".to_string()))
277}
278
279async fn ws_handler_with_proxy_path(
280    Path(path): Path<String>,
281    ws: WebSocketUpgrade,
282    State(proxy): State<WsProxyHandler>,
283) -> impl IntoResponse {
284    let full_path = format!("/ws/{}", path);
285    ws.on_upgrade(move |socket| handle_socket_with_proxy(socket, proxy, full_path))
286}
287
288async fn ws_handler_with_registry(
289    ws: WebSocketUpgrade,
290    State(registry): State<std::sync::Arc<HandlerRegistry>>,
291) -> impl IntoResponse {
292    ws.on_upgrade(move |socket| handle_socket_with_handlers(socket, registry, "/ws".to_string()))
293}
294
295async fn ws_handler_with_registry_path(
296    Path(path): Path<String>,
297    ws: WebSocketUpgrade,
298    State(registry): State<std::sync::Arc<HandlerRegistry>>,
299) -> impl IntoResponse {
300    let full_path = format!("/ws/{}", path);
301    ws.on_upgrade(move |socket| handle_socket_with_handlers(socket, registry, full_path))
302}
303
304async fn handle_socket(mut socket: WebSocket) {
305    use std::time::Instant;
306
307    // Track WebSocket connection
308    let registry = get_global_registry();
309    let connection_start = Instant::now();
310    registry.record_ws_connection_established();
311    debug!("WebSocket connection established, tracking metrics");
312
313    // Track connection status (for metrics reporting)
314    let mut status = "normal";
315
316    // Check if replay mode is enabled
317    if let Ok(replay_file) = std::env::var("MOCKFORGE_WS_REPLAY_FILE") {
318        info!("WebSocket replay mode enabled with file: {}", replay_file);
319        handle_socket_with_replay(socket, &replay_file).await;
320    } else {
321        // Normal echo mode
322        while let Some(msg) = socket.recv().await {
323            match msg {
324                Ok(Message::Text(text)) => {
325                    registry.record_ws_message_received();
326
327                    // Echo the message back with "echo: " prefix
328                    let response = format!("echo: {}", text);
329                    if socket.send(Message::Text(response.into())).await.is_err() {
330                        status = "send_error";
331                        break;
332                    }
333                    registry.record_ws_message_sent();
334                }
335                Ok(Message::Close(_)) => {
336                    status = "client_close";
337                    break;
338                }
339                Err(e) => {
340                    error!("WebSocket error: {}", e);
341                    registry.record_ws_error();
342                    status = "error";
343                    break;
344                }
345                _ => {}
346            }
347        }
348    }
349
350    // Connection closed - record duration
351    let duration = connection_start.elapsed().as_secs_f64();
352    registry.record_ws_connection_closed(duration, status);
353    debug!("WebSocket connection closed (status: {}, duration: {:.2}s)", status, duration);
354}
355
356async fn handle_socket_with_replay(mut socket: WebSocket, replay_file: &str) {
357    let _registry = get_global_registry(); // Available for future message tracking
358
359    // Read the replay file
360    let content = match fs::read_to_string(replay_file).await {
361        Ok(content) => content,
362        Err(e) => {
363            error!("Failed to read replay file {}: {}", replay_file, e);
364            return;
365        }
366    };
367
368    // Parse JSONL file
369    let mut replay_entries = Vec::new();
370    for line in content.lines() {
371        if let Ok(entry) = serde_json::from_str::<Value>(line) {
372            replay_entries.push(entry);
373        }
374    }
375
376    info!("Loaded {} replay entries", replay_entries.len());
377
378    // Process replay entries
379    for entry in replay_entries {
380        // Check if we need to wait for a specific message
381        if let Some(wait_for) = entry.get("waitFor") {
382            if let Some(wait_pattern) = wait_for.as_str() {
383                info!("Waiting for pattern: {}", wait_pattern);
384                // Wait for matching message from client
385                let mut found = false;
386                while let Some(msg) = socket.recv().await {
387                    if let Ok(Message::Text(text)) = msg {
388                        if text.contains(wait_pattern) || wait_pattern == "^CLIENT_READY$" {
389                            found = true;
390                            break;
391                        }
392                    }
393                }
394                if !found {
395                    break;
396                }
397            }
398        }
399
400        // Get the message text
401        if let Some(text) = entry.get("text").and_then(|v| v.as_str()) {
402            // Expand tokens if enabled
403            let expanded_text = if std::env::var("MOCKFORGE_RESPONSE_TEMPLATE_EXPAND")
404                .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
405                .unwrap_or(false)
406            {
407                expand_tokens(text)
408            } else {
409                text.to_string()
410            };
411
412            info!("Sending replay message: {}", expanded_text);
413            if socket.send(Message::Text(expanded_text.into())).await.is_err() {
414                break;
415            }
416        }
417
418        // Wait for the specified time
419        if let Some(ts) = entry.get("ts").and_then(|v| v.as_u64()) {
420            sleep(Duration::from_millis(ts * 10)).await; // Convert to milliseconds
421        }
422    }
423}
424
425fn expand_tokens(text: &str) -> String {
426    let mut result = text.to_string();
427
428    // Expand {{uuid}}
429    result = result.replace("{{uuid}}", &uuid::Uuid::new_v4().to_string());
430
431    // Expand {{now}}
432    result = result.replace("{{now}}", &chrono::Utc::now().to_rfc3339());
433
434    // Expand {{now+1m}} (add 1 minute)
435    if result.contains("{{now+1m}}") {
436        let now_plus_1m = chrono::Utc::now() + chrono::Duration::minutes(1);
437        result = result.replace("{{now+1m}}", &now_plus_1m.to_rfc3339());
438    }
439
440    // Expand {{now+1h}} (add 1 hour)
441    if result.contains("{{now+1h}}") {
442        let now_plus_1h = chrono::Utc::now() + chrono::Duration::hours(1);
443        result = result.replace("{{now+1h}}", &now_plus_1h.to_rfc3339());
444    }
445
446    // Expand {{randInt min max}}
447    while result.contains("{{randInt") {
448        if let Some(start) = result.find("{{randInt") {
449            if let Some(end) = result[start..].find("}}") {
450                let full_match = &result[start..start + end + 2];
451                let content = &result[start + 9..start + end]; // Skip "{{randInt"
452
453                if let Some(space_pos) = content.find(' ') {
454                    let min_str = &content[..space_pos];
455                    let max_str = &content[space_pos + 1..];
456
457                    if let (Ok(min), Ok(max)) = (min_str.parse::<i32>(), max_str.parse::<i32>()) {
458                        let random_value = fastrand::i32(min..=max);
459                        result = result.replace(full_match, &random_value.to_string());
460                    } else {
461                        result = result.replace(full_match, "0");
462                    }
463                } else {
464                    result = result.replace(full_match, "0");
465                }
466            } else {
467                break;
468            }
469        } else {
470            break;
471        }
472    }
473
474    result
475}
476
477async fn handle_socket_with_proxy(socket: WebSocket, proxy: WsProxyHandler, path: String) {
478    use std::time::Instant;
479
480    let registry = get_global_registry();
481    let connection_start = Instant::now();
482    registry.record_ws_connection_established();
483
484    let mut status = "normal";
485
486    // Check if this connection should be proxied
487    if proxy.config.should_proxy(&path) {
488        info!("Proxying WebSocket connection for path: {}", path);
489        if let Err(e) = proxy.proxy_connection(&path, socket).await {
490            error!("Failed to proxy WebSocket connection: {}", e);
491            registry.record_ws_error();
492            status = "proxy_error";
493        }
494    } else {
495        info!("Handling WebSocket connection locally for path: {}", path);
496        // Handle locally by echoing messages
497        // Note: handle_socket already tracks its own connection metrics,
498        // so we need to avoid double-counting
499        registry.record_ws_connection_closed(0.0, ""); // Decrement the one we just added
500        handle_socket(socket).await;
501        return; // Early return to avoid double-tracking
502    }
503
504    let duration = connection_start.elapsed().as_secs_f64();
505    registry.record_ws_connection_closed(duration, status);
506    debug!(
507        "Proxied WebSocket connection closed (status: {}, duration: {:.2}s)",
508        status, duration
509    );
510}
511
512async fn handle_socket_with_handlers(
513    socket: WebSocket,
514    registry: std::sync::Arc<HandlerRegistry>,
515    path: String,
516) {
517    use std::time::Instant;
518
519    let metrics_registry = get_global_registry();
520    let connection_start = Instant::now();
521    metrics_registry.record_ws_connection_established();
522
523    let mut status = "normal";
524
525    // Generate unique connection ID
526    let connection_id = uuid::Uuid::new_v4().to_string();
527
528    // Get handlers for this path
529    let handlers = registry.get_handlers(&path);
530    if handlers.is_empty() {
531        info!("No handlers found for path: {}, falling back to echo mode", path);
532        metrics_registry.record_ws_connection_closed(0.0, "");
533        handle_socket(socket).await;
534        return;
535    }
536
537    info!(
538        "Handling WebSocket connection with {} handler(s) for path: {}",
539        handlers.len(),
540        path
541    );
542
543    // Create room manager
544    let room_manager = RoomManager::new();
545
546    // Split socket for concurrent send/receive
547    let (mut socket_sender, mut socket_receiver) = socket.split();
548
549    // Create message channel for handlers to send messages
550    let (message_tx, mut message_rx) = tokio::sync::mpsc::unbounded_channel::<Message>();
551
552    // Create context
553    let mut ctx =
554        WsContext::new(connection_id.clone(), path.clone(), room_manager.clone(), message_tx);
555
556    // Call on_connect for all handlers
557    for handler in &handlers {
558        if let Err(e) = handler.on_connect(&mut ctx).await {
559            error!("Handler on_connect error: {}", e);
560            status = "handler_error";
561        }
562    }
563
564    // Spawn task to send messages from handlers to the socket
565    let send_task = tokio::spawn(async move {
566        while let Some(msg) = message_rx.recv().await {
567            if socket_sender.send(msg).await.is_err() {
568                break;
569            }
570        }
571    });
572
573    // Handle incoming messages
574    while let Some(msg) = socket_receiver.next().await {
575        match msg {
576            Ok(axum_msg) => {
577                metrics_registry.record_ws_message_received();
578
579                let ws_msg: WsMessage = axum_msg.into();
580
581                // Check for close message
582                if matches!(ws_msg, WsMessage::Close) {
583                    status = "client_close";
584                    break;
585                }
586
587                // Pass message through all handlers
588                for handler in &handlers {
589                    if let Err(e) = handler.on_message(&mut ctx, ws_msg.clone()).await {
590                        error!("Handler on_message error: {}", e);
591                        status = "handler_error";
592                    }
593                }
594
595                metrics_registry.record_ws_message_sent();
596            }
597            Err(e) => {
598                error!("WebSocket error: {}", e);
599                metrics_registry.record_ws_error();
600                status = "error";
601                break;
602            }
603        }
604    }
605
606    // Call on_disconnect for all handlers
607    for handler in &handlers {
608        if let Err(e) = handler.on_disconnect(&mut ctx).await {
609            error!("Handler on_disconnect error: {}", e);
610        }
611    }
612
613    // Clean up room memberships
614    let _ = room_manager.leave_all(&connection_id).await;
615
616    // Abort send task
617    send_task.abort();
618
619    let duration = connection_start.elapsed().as_secs_f64();
620    metrics_registry.record_ws_connection_closed(duration, status);
621    debug!(
622        "Handler-based WebSocket connection closed (status: {}, duration: {:.2}s)",
623        status, duration
624    );
625}
626
#[cfg(test)]
mod tests {
    use super::*;

    // Router construction tests: a panic during construction fails the test directly, so no
    // `catch_unwind` or async runtime is needed.

    #[test]
    fn test_router_creation() {
        let _router = router();
    }

    #[test]
    fn test_router_with_latency_creation() {
        let latency_injector = LatencyInjector::new(LatencyProfile::default(), Default::default());
        let _router = router_with_latency(latency_injector);
    }

    #[test]
    fn test_router_with_proxy_creation() {
        let config = mockforge_core::WsProxyConfig {
            upstream_url: "ws://localhost:8080".to_string(),
            ..Default::default()
        };
        let proxy_handler = WsProxyHandler::new(config);
        let _router = router_with_proxy(proxy_handler);
    }

    #[test]
    fn test_start_with_latency_config_none() {
        // Same router `start_with_latency` builds when no latency profile is given.
        let _router = router();
    }

    #[test]
    fn test_start_with_latency_config_some() {
        // Same router `start_with_latency` builds when a latency profile is given.
        let latency_injector = LatencyInjector::new(LatencyProfile::default(), Default::default());
        let _router = router_with_latency(latency_injector);
    }

    #[test]
    fn test_expand_tokens_leaves_plain_text_untouched() {
        assert_eq!(expand_tokens("plain message"), "plain message");
    }

    #[test]
    fn test_expand_tokens_uuid() {
        let out = expand_tokens("{{uuid}}");
        assert!(uuid::Uuid::parse_str(&out).is_ok(), "not a UUID: {}", out);
    }

    #[test]
    fn test_expand_tokens_now_is_rfc3339() {
        let out = expand_tokens("{{now}}");
        assert!(chrono::DateTime::parse_from_rfc3339(&out).is_ok(), "not RFC 3339: {}", out);
    }

    #[test]
    fn test_expand_tokens_malformed_rand_int() {
        assert_eq!(expand_tokens("{{randInt}}"), "0");
        assert_eq!(expand_tokens("x {{randInt a b}}"), "x 0");
        // Unterminated tokens are left as-is.
        assert_eq!(expand_tokens("{{randInt 1 2"), "{{randInt 1 2");
    }
}