mockforge_recorder/protocols/
websocket.rs1use crate::{models::*, recorder::Recorder};
4use chrono::Utc;
5use std::collections::HashMap;
6use tracing::debug;
7use uuid::Uuid;
8
9pub async fn record_ws_connection(
11 recorder: &Recorder,
12 path: &str,
13 headers: &HashMap<String, String>,
14 client_ip: Option<&str>,
15 trace_id: Option<&str>,
16 span_id: Option<&str>,
17) -> Result<String, crate::RecorderError> {
18 let request_id = Uuid::new_v4().to_string();
19
20 let request = RecordedRequest {
21 id: request_id.clone(),
22 protocol: Protocol::WebSocket,
23 timestamp: Utc::now(),
24 method: "CONNECT".to_string(),
25 path: path.to_string(),
26 query_params: None,
27 headers: serde_json::to_string(&headers)?,
28 body: None,
29 body_encoding: "utf8".to_string(),
30 client_ip: client_ip.map(|s| s.to_string()),
31 trace_id: trace_id.map(|s| s.to_string()),
32 span_id: span_id.map(|s| s.to_string()),
33 duration_ms: None,
34 status_code: Some(101), tags: Some(serde_json::to_string(&vec!["websocket", "connection"])?),
36 };
37
38 recorder.record_request(request).await?;
39 debug!("Recorded WebSocket connection: {} {}", request_id, path);
40
41 Ok(request_id)
42}
43
44pub async fn record_ws_message(
46 recorder: &Recorder,
47 connection_id: &str,
48 direction: &str, message: &[u8],
50 is_binary: bool,
51 trace_id: Option<&str>,
52 span_id: Option<&str>,
53) -> Result<String, crate::RecorderError> {
54 let message_id = Uuid::new_v4().to_string();
55
56 let (body_str, body_encoding) = if is_binary {
57 let encoded = base64::Engine::encode(&base64::engine::general_purpose::STANDARD, message);
58 (Some(encoded), "base64".to_string())
59 } else {
60 match std::str::from_utf8(message) {
61 Ok(text) => (Some(text.to_string()), "utf8".to_string()),
62 Err(_) => {
63 let encoded =
64 base64::Engine::encode(&base64::engine::general_purpose::STANDARD, message);
65 (Some(encoded), "base64".to_string())
66 }
67 }
68 };
69
70 let tags = vec!["websocket", "message", direction];
71
72 let request = RecordedRequest {
73 id: message_id.clone(),
74 protocol: Protocol::WebSocket,
75 timestamp: Utc::now(),
76 method: direction.to_uppercase(),
77 path: format!("/ws/{}", connection_id),
78 query_params: None,
79 headers: serde_json::to_string(&HashMap::from([(
80 "ws-connection-id".to_string(),
81 connection_id.to_string(),
82 )]))?,
83 body: body_str,
84 body_encoding,
85 client_ip: None,
86 trace_id: trace_id.map(|s| s.to_string()),
87 span_id: span_id.map(|s| s.to_string()),
88 duration_ms: None,
89 status_code: Some(200),
90 tags: Some(serde_json::to_string(&tags)?),
91 };
92
93 recorder.record_request(request).await?;
94 debug!(
95 "Recorded WebSocket message: {} {} {} bytes",
96 message_id,
97 direction,
98 message.len()
99 );
100
101 Ok(message_id)
102}
103
104pub async fn record_ws_disconnection(
106 recorder: &Recorder,
107 connection_id: &str,
108 reason: Option<&str>,
109 duration_ms: i64,
110) -> Result<(), crate::RecorderError> {
111 let disconnect_id = Uuid::new_v4().to_string();
112
113 let request = RecordedRequest {
114 id: disconnect_id.clone(),
115 protocol: Protocol::WebSocket,
116 timestamp: Utc::now(),
117 method: "DISCONNECT".to_string(),
118 path: format!("/ws/{}", connection_id),
119 query_params: None,
120 headers: serde_json::to_string(&HashMap::from([(
121 "ws-connection-id".to_string(),
122 connection_id.to_string(),
123 )]))?,
124 body: reason.map(|r| r.to_string()),
125 body_encoding: "utf8".to_string(),
126 client_ip: None,
127 trace_id: None,
128 span_id: None,
129 duration_ms: Some(duration_ms),
130 status_code: Some(1000), tags: Some(serde_json::to_string(&vec!["websocket", "disconnection"])?),
132 };
133
134 recorder.record_request(request).await?;
135 debug!(
136 "Recorded WebSocket disconnection: {} connection={} duration={}ms",
137 disconnect_id, connection_id, duration_ms
138 );
139
140 Ok(())
141}
142
#[cfg(test)]
mod tests {
    // Round-trip tests: record through the public helpers, then read the entry back from an
    // in-memory database by the id the helper returned.
    use super::*;
    use crate::database::RecorderDatabase;

    #[tokio::test]
    async fn test_record_ws_connection() {
        let db = RecorderDatabase::new_in_memory().await.unwrap();
        let recorder = Recorder::new(db);

        let headers = HashMap::from([
            ("upgrade".to_string(), "websocket".to_string()),
            ("connection".to_string(), "Upgrade".to_string()),
        ]);

        let connection_id =
            record_ws_connection(&recorder, "/ws/chat", &headers, Some("127.0.0.1"), None, None)
                .await
                .unwrap();

        let exchange = recorder
            .database()
            .get_exchange(&connection_id)
            .await
            .unwrap()
            .expect("connection should have been recorded");
        assert_eq!(exchange.request.protocol, Protocol::WebSocket);
        assert_eq!(exchange.request.method, "CONNECT");
    }

    #[tokio::test]
    async fn test_record_ws_message() {
        let db = RecorderDatabase::new_in_memory().await.unwrap();
        let recorder = Recorder::new(db);

        let message_id = record_ws_message(
            &recorder,
            "conn-123",
            "inbound",
            b"Hello, WebSocket!",
            false,
            None,
            None,
        )
        .await
        .unwrap();

        let exchange = recorder
            .database()
            .get_exchange(&message_id)
            .await
            .unwrap()
            .expect("message should have been recorded");
        assert_eq!(exchange.request.protocol, Protocol::WebSocket);
        assert_eq!(exchange.request.method, "INBOUND");
        // Valid UTF-8 text is stored verbatim.
        assert_eq!(exchange.request.body_encoding, "utf8");
        assert_eq!(exchange.request.body.as_deref(), Some("Hello, WebSocket!"));
    }

    #[tokio::test]
    async fn test_record_ws_message_binary_is_base64() {
        let db = RecorderDatabase::new_in_memory().await.unwrap();
        let recorder = Recorder::new(db);

        let message_id = record_ws_message(
            &recorder,
            "conn-123",
            "outbound",
            b"\xde\xad\xbe\xef",
            true,
            None,
            None,
        )
        .await
        .unwrap();

        let exchange = recorder
            .database()
            .get_exchange(&message_id)
            .await
            .unwrap()
            .expect("binary message should have been recorded");
        assert_eq!(exchange.request.method, "OUTBOUND");
        assert_eq!(exchange.request.body_encoding, "base64");
        assert_eq!(exchange.request.body.as_deref(), Some("3q2+7w=="));
    }

    #[tokio::test]
    async fn test_record_ws_message_invalid_utf8_text_falls_back_to_base64() {
        let db = RecorderDatabase::new_in_memory().await.unwrap();
        let recorder = Recorder::new(db);

        // Flagged as text, but the bytes are not valid UTF-8.
        let message_id =
            record_ws_message(&recorder, "conn-123", "inbound", b"\xff\xfe", false, None, None)
                .await
                .unwrap();

        let exchange = recorder
            .database()
            .get_exchange(&message_id)
            .await
            .unwrap()
            .expect("malformed text message should have been recorded");
        assert_eq!(exchange.request.body_encoding, "base64");
        assert_eq!(exchange.request.body.as_deref(), Some("//4="));
    }

    #[tokio::test]
    async fn test_record_ws_disconnection() {
        let db = RecorderDatabase::new_in_memory().await.unwrap();
        let recorder = Recorder::new(db);

        record_ws_disconnection(&recorder, "conn-123", Some("client closed"), 42)
            .await
            .unwrap();
    }
}