mockforge_recorder/protocols/
websocket.rs

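//! WebSocket recording helpers: persist connection handshakes, individual
//! messages, and disconnections as `RecordedRequest` entries.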
2
3use crate::{models::*, recorder::Recorder};
4use chrono::Utc;
5use std::collections::HashMap;
6use tracing::debug;
7use uuid::Uuid;
8
9/// Record a WebSocket connection request
10pub async fn record_ws_connection(
11    recorder: &Recorder,
12    path: &str,
13    headers: &HashMap<String, String>,
14    client_ip: Option<&str>,
15    trace_id: Option<&str>,
16    span_id: Option<&str>,
17) -> Result<String, crate::RecorderError> {
18    let request_id = Uuid::new_v4().to_string();
19
20    let request = RecordedRequest {
21        id: request_id.clone(),
22        protocol: Protocol::WebSocket,
23        timestamp: Utc::now(),
24        method: "CONNECT".to_string(),
25        path: path.to_string(),
26        query_params: None,
27        headers: serde_json::to_string(&headers)?,
28        body: None,
29        body_encoding: "utf8".to_string(),
30        client_ip: client_ip.map(|s| s.to_string()),
31        trace_id: trace_id.map(|s| s.to_string()),
32        span_id: span_id.map(|s| s.to_string()),
33        duration_ms: None,
34        status_code: Some(101), // Switching Protocols
35        tags: Some(serde_json::to_string(&vec!["websocket", "connection"])?),
36    };
37
38    recorder.record_request(request).await?;
39    debug!("Recorded WebSocket connection: {} {}", request_id, path);
40
41    Ok(request_id)
42}
43
44/// Record a WebSocket message
45pub async fn record_ws_message(
46    recorder: &Recorder,
47    connection_id: &str,
48    direction: &str, // "inbound" or "outbound"
49    message: &[u8],
50    is_binary: bool,
51    trace_id: Option<&str>,
52    span_id: Option<&str>,
53) -> Result<String, crate::RecorderError> {
54    let message_id = Uuid::new_v4().to_string();
55
56    let (body_str, body_encoding) = if is_binary {
57        let encoded = base64::Engine::encode(&base64::engine::general_purpose::STANDARD, message);
58        (Some(encoded), "base64".to_string())
59    } else {
60        match std::str::from_utf8(message) {
61            Ok(text) => (Some(text.to_string()), "utf8".to_string()),
62            Err(_) => {
63                let encoded =
64                    base64::Engine::encode(&base64::engine::general_purpose::STANDARD, message);
65                (Some(encoded), "base64".to_string())
66            }
67        }
68    };
69
70    let tags = vec!["websocket", "message", direction];
71
72    let request = RecordedRequest {
73        id: message_id.clone(),
74        protocol: Protocol::WebSocket,
75        timestamp: Utc::now(),
76        method: direction.to_uppercase(),
77        path: format!("/ws/{}", connection_id),
78        query_params: None,
79        headers: serde_json::to_string(&HashMap::from([(
80            "ws-connection-id".to_string(),
81            connection_id.to_string(),
82        )]))?,
83        body: body_str,
84        body_encoding,
85        client_ip: None,
86        trace_id: trace_id.map(|s| s.to_string()),
87        span_id: span_id.map(|s| s.to_string()),
88        duration_ms: None,
89        status_code: Some(200),
90        tags: Some(serde_json::to_string(&tags)?),
91    };
92
93    recorder.record_request(request).await?;
94    debug!(
95        "Recorded WebSocket message: {} {} {} bytes",
96        message_id,
97        direction,
98        message.len()
99    );
100
101    Ok(message_id)
102}
103
104/// Record WebSocket disconnection
105pub async fn record_ws_disconnection(
106    recorder: &Recorder,
107    connection_id: &str,
108    reason: Option<&str>,
109    duration_ms: i64,
110) -> Result<(), crate::RecorderError> {
111    let disconnect_id = Uuid::new_v4().to_string();
112
113    let request = RecordedRequest {
114        id: disconnect_id.clone(),
115        protocol: Protocol::WebSocket,
116        timestamp: Utc::now(),
117        method: "DISCONNECT".to_string(),
118        path: format!("/ws/{}", connection_id),
119        query_params: None,
120        headers: serde_json::to_string(&HashMap::from([(
121            "ws-connection-id".to_string(),
122            connection_id.to_string(),
123        )]))?,
124        body: reason.map(|r| r.to_string()),
125        body_encoding: "utf8".to_string(),
126        client_ip: None,
127        trace_id: None,
128        span_id: None,
129        duration_ms: Some(duration_ms),
130        status_code: Some(1000), // Normal closure
131        tags: Some(serde_json::to_string(&vec!["websocket", "disconnection"])?),
132    };
133
134    recorder.record_request(request).await?;
135    debug!(
136        "Recorded WebSocket disconnection: {} connection={} duration={}ms",
137        disconnect_id, connection_id, duration_ms
138    );
139
140    Ok(())
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use crate::database::RecorderDatabase;

    #[tokio::test]
    async fn test_record_ws_connection() {
        let db = RecorderDatabase::new_in_memory().await.unwrap();
        let recorder = Recorder::new(db);

        let headers = HashMap::from([
            ("upgrade".to_string(), "websocket".to_string()),
            ("connection".to_string(), "Upgrade".to_string()),
        ]);

        let connection_id =
            record_ws_connection(&recorder, "/ws/chat", &headers, Some("127.0.0.1"), None, None)
                .await
                .unwrap();

        // Verify it was recorded
        let exchange = recorder.database().get_exchange(&connection_id).await.unwrap();
        assert!(exchange.is_some());

        let exchange = exchange.unwrap();
        assert_eq!(exchange.request.protocol, Protocol::WebSocket);
        assert_eq!(exchange.request.method, "CONNECT");
    }

    #[tokio::test]
    async fn test_record_ws_message() {
        let db = RecorderDatabase::new_in_memory().await.unwrap();
        let recorder = Recorder::new(db);

        let message_id = record_ws_message(
            &recorder,
            "conn-123",
            "inbound",
            b"Hello, WebSocket!",
            false,
            None,
            None,
        )
        .await
        .unwrap();

        // Verify it was recorded, with the text payload stored verbatim
        let exchange = recorder.database().get_exchange(&message_id).await.unwrap();
        assert!(exchange.is_some());

        let exchange = exchange.unwrap();
        assert_eq!(exchange.request.protocol, Protocol::WebSocket);
        assert_eq!(exchange.request.method, "INBOUND");
        assert_eq!(exchange.request.body_encoding, "utf8");
        assert_eq!(exchange.request.body.as_deref(), Some("Hello, WebSocket!"));
    }

    #[tokio::test]
    async fn test_record_ws_binary_message_is_base64_encoded() {
        let db = RecorderDatabase::new_in_memory().await.unwrap();
        let recorder = Recorder::new(db);

        let message_id =
            record_ws_message(&recorder, "conn-123", "outbound", &[0, 1, 2], true, None, None)
                .await
                .unwrap();

        let exchange = recorder.database().get_exchange(&message_id).await.unwrap().unwrap();
        assert_eq!(exchange.request.method, "OUTBOUND");
        assert_eq!(exchange.request.body_encoding, "base64");
        assert_eq!(exchange.request.body.as_deref(), Some("AAEC"));
    }

    #[tokio::test]
    async fn test_record_ws_invalid_utf8_text_falls_back_to_base64() {
        let db = RecorderDatabase::new_in_memory().await.unwrap();
        let recorder = Recorder::new(db);

        // Flagged as text, but the bytes are not valid UTF-8.
        let message_id =
            record_ws_message(&recorder, "conn-123", "inbound", b"\xff\xfe", false, None, None)
                .await
                .unwrap();

        let exchange = recorder.database().get_exchange(&message_id).await.unwrap().unwrap();
        assert_eq!(exchange.request.body_encoding, "base64");
        assert_eq!(exchange.request.body.as_deref(), Some("//4="));
    }

    #[tokio::test]
    async fn test_record_ws_disconnection() {
        let db = RecorderDatabase::new_in_memory().await.unwrap();
        let recorder = Recorder::new(db);

        record_ws_disconnection(&recorder, "conn-123", Some("client closed"), 1_500)
            .await
            .unwrap();
    }
}