cortex_runtime/live/
websocket.rs1use anyhow::{bail, Result};
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17use tokio::sync::Mutex;
18
19pub use crate::acquisition::ws_discovery::{WsAuth, WsEndpoint, WsProtocol};
21
22pub struct WsSession {
27 pub url: String,
29 pub protocol: WsProtocol,
31 pub domain: String,
33 connected: bool,
35 messages: Vec<WsMessage>,
37 max_history: usize,
39 _inner: Mutex<Option<WsInner>>,
41}
42
43struct WsInner {
45 sink: futures::stream::SplitSink<
47 tokio_tungstenite::WebSocketStream<
48 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
49 >,
50 tokio_tungstenite::tungstenite::Message,
51 >,
52 stream: futures::stream::SplitStream<
54 tokio_tungstenite::WebSocketStream<
55 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
56 >,
57 >,
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct WsMessage {
63 pub direction: WsDirection,
65 pub payload: String,
67 pub timestamp_ms: u64,
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
73pub enum WsDirection {
74 Sent,
76 Received,
78}
79
80impl WsSession {
81 pub async fn connect(endpoint: &WsEndpoint, cookies: &HashMap<String, String>) -> Result<Self> {
86 use futures::StreamExt;
87 use tokio_tungstenite::tungstenite::http::Request;
88
89 let ws_url = &endpoint.url;
91
92 let mut request_builder = Request::builder().uri(ws_url);
94
95 if !cookies.is_empty() {
97 let cookie_str: String = cookies
98 .iter()
99 .map(|(k, v)| format!("{k}={v}"))
100 .collect::<Vec<_>>()
101 .join("; ");
102 request_builder = request_builder.header("Cookie", cookie_str);
103 }
104
105 let origin = if let Ok(parsed) = url::Url::parse(ws_url) {
107 format!(
108 "{}://{}",
109 if parsed.scheme() == "wss" {
110 "https"
111 } else {
112 "http"
113 },
114 parsed.host_str().unwrap_or("localhost")
115 )
116 } else {
117 "https://localhost".to_string()
118 };
119 request_builder = request_builder.header("Origin", &origin);
120
121 let request = request_builder
122 .body(())
123 .map_err(|e| anyhow::anyhow!("failed to build WS request: {e}"))?;
124
125 let (ws_stream, _response) = tokio_tungstenite::connect_async(request)
127 .await
128 .map_err(|e| anyhow::anyhow!("WebSocket connection failed: {e}"))?;
129
130 let (sink, stream) = ws_stream.split();
131
132 let domain = url::Url::parse(ws_url)
133 .ok()
134 .and_then(|u| u.host_str().map(|s| s.to_string()))
135 .unwrap_or_default();
136
137 Ok(WsSession {
138 url: ws_url.clone(),
139 protocol: endpoint.protocol.clone(),
140 domain,
141 connected: true,
142 messages: Vec::new(),
143 max_history: 1000,
144 _inner: Mutex::new(Some(WsInner { sink, stream })),
145 })
146 }
147
148 pub async fn send_json<T: Serialize>(&mut self, msg: &T) -> Result<()> {
153 use futures::SinkExt;
154 use tokio_tungstenite::tungstenite::Message;
155
156 if !self.connected {
157 bail!("WebSocket is not connected");
158 }
159
160 let payload = serde_json::to_string(msg)?;
161
162 let wire_payload = match &self.protocol {
164 WsProtocol::SocketIO => format!("42{payload}"),
165 _ => payload.clone(),
166 };
167
168 let mut inner_guard = self._inner.lock().await;
169 if let Some(inner) = inner_guard.as_mut() {
170 inner
171 .sink
172 .send(Message::Text(wire_payload))
173 .await
174 .map_err(|e| anyhow::anyhow!("failed to send WS message: {e}"))?;
175 } else {
176 bail!("WebSocket connection not available");
177 }
178 drop(inner_guard);
179
180 self.messages.push(WsMessage {
181 direction: WsDirection::Sent,
182 payload,
183 timestamp_ms: 0, });
185
186 if self.messages.len() > self.max_history {
188 let drain = self.messages.len() - self.max_history;
189 self.messages.drain(..drain);
190 }
191
192 Ok(())
193 }
194
195 pub async fn receive(&mut self) -> Result<Option<String>> {
200 use futures::StreamExt;
201 use tokio_tungstenite::tungstenite::Message;
202
203 loop {
204 if !self.connected {
205 return Ok(None);
206 }
207
208 let mut inner_guard = self._inner.lock().await;
209 let inner = match inner_guard.as_mut() {
210 Some(i) => i,
211 None => return Ok(None),
212 };
213
214 match inner.stream.next().await {
215 Some(Ok(Message::Text(text))) => {
216 let payload = match &self.protocol {
218 WsProtocol::SocketIO => text
219 .strip_prefix("42")
220 .map(|s| s.to_string())
221 .unwrap_or(text),
222 _ => text,
223 };
224
225 drop(inner_guard);
226
227 self.messages.push(WsMessage {
228 direction: WsDirection::Received,
229 payload: payload.clone(),
230 timestamp_ms: 0,
231 });
232
233 if self.messages.len() > self.max_history {
234 let drain = self.messages.len() - self.max_history;
235 self.messages.drain(..drain);
236 }
237
238 return Ok(Some(payload));
239 }
240 Some(Ok(Message::Binary(data))) => {
241 drop(inner_guard);
242 return Ok(Some(format!("[binary: {} bytes]", data.len())));
243 }
244 Some(Ok(Message::Close(_))) => {
245 drop(inner_guard);
246 self.connected = false;
247 return Ok(None);
248 }
249 Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => {
250 drop(inner_guard);
252 continue;
253 }
254 Some(Err(e)) => {
255 drop(inner_guard);
256 self.connected = false;
257 bail!("WebSocket error: {e}");
258 }
259 None => {
260 drop(inner_guard);
261 self.connected = false;
262 return Ok(None);
263 }
264 }
265 }
266 }
267
268 pub async fn watch(&mut self, duration_ms: u64) -> Result<Vec<WsMessage>> {
270 let mut collected = Vec::new();
271 let deadline =
272 tokio::time::Instant::now() + tokio::time::Duration::from_millis(duration_ms);
273
274 loop {
275 let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
276 if remaining.is_zero() {
277 break;
278 }
279
280 match tokio::time::timeout(remaining, self.receive()).await {
281 Ok(Ok(Some(payload))) => {
282 collected.push(WsMessage {
283 direction: WsDirection::Received,
284 payload,
285 timestamp_ms: 0,
286 });
287 }
288 Ok(Ok(None)) => break, Ok(Err(_)) => break, Err(_) => break, }
292 }
293
294 Ok(collected)
295 }
296
297 pub async fn close(&mut self) -> Result<()> {
299 use futures::SinkExt;
300 use tokio_tungstenite::tungstenite::Message;
301
302 if !self.connected {
303 return Ok(());
304 }
305
306 let mut inner_guard = self._inner.lock().await;
307 if let Some(inner) = inner_guard.as_mut() {
308 inner.sink.send(Message::Close(None)).await.ok();
309 }
310 *inner_guard = None;
311 drop(inner_guard);
312
313 self.connected = false;
314 Ok(())
315 }
316
317 pub fn is_connected(&self) -> bool {
319 self.connected
320 }
321
322 pub fn history(&self) -> &[WsMessage] {
324 &self.messages
325 }
326}
327
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ws_message_serde() {
        let msg = WsMessage {
            direction: WsDirection::Received,
            payload: r#"{"type":"update","data":42}"#.to_string(),
            timestamp_ms: 12345,
        };

        let json = serde_json::to_string(&msg).unwrap();
        let parsed: WsMessage = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.direction, WsDirection::Received);
        assert_eq!(parsed.timestamp_ms, 12345);
        // The payload is an opaque string and must survive the round trip byte-for-byte.
        assert_eq!(parsed.payload, msg.payload);
    }

    #[test]
    fn test_ws_direction_wire_format() {
        // Unit variants serialize as their bare names; pin this so a variant rename is caught.
        assert_eq!(serde_json::to_string(&WsDirection::Sent).unwrap(), r#""Sent""#);
        assert_eq!(
            serde_json::from_str::<WsDirection>(r#""Received""#).unwrap(),
            WsDirection::Received
        );
    }

    #[test]
    fn test_ws_direction_eq() {
        assert_eq!(WsDirection::Sent, WsDirection::Sent);
        assert_ne!(WsDirection::Sent, WsDirection::Received);
    }
}