Skip to main content

cortex_runtime/live/
websocket.rs

1//! Native WebSocket client for real-time page interaction.
2//!
3//! Provides a thin wrapper around `tokio-tungstenite` for opening, sending,
4//! receiving, and closing WebSocket connections discovered by
5//! [`crate::acquisition::ws_discovery`]. This avoids spinning up a browser
6//! just to interact with sites that use WebSockets for their primary data
7//! transport (Slack, Discord, real-time dashboards, etc.).
8//!
9//! ## Execution priority
10//!
11//! In the action execution stack, WebSocket sits at priority 4
12//! (after WebMCP, Platform API, and HTTP Action).
13
14use anyhow::{bail, Result};
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17use tokio::sync::Mutex;
18
19// Re-export discovery types for convenience.
20pub use crate::acquisition::ws_discovery::{WsAuth, WsEndpoint, WsProtocol};
21
22/// An active WebSocket session.
23///
24/// Wraps a `tokio-tungstenite` connection with session metadata
25/// (cookies, auth tokens, protocol details).
26pub struct WsSession {
27    /// The WebSocket URL this session is connected to.
28    pub url: String,
29    /// The protocol used (Raw, Socket.IO, SockJS, SignalR).
30    pub protocol: WsProtocol,
31    /// Domain of the connected site.
32    pub domain: String,
33    /// Whether the connection is currently open.
34    connected: bool,
35    /// Message history (most recent messages, bounded).
36    messages: Vec<WsMessage>,
37    /// Maximum messages to keep in history.
38    max_history: usize,
39    /// Internal sink/stream — wrapped in Mutex for interior mutability.
40    _inner: Mutex<Option<WsInner>>,
41}
42
43/// Internal WebSocket connection state.
44struct WsInner {
45    /// The underlying tungstenite write half.
46    sink: futures::stream::SplitSink<
47        tokio_tungstenite::WebSocketStream<
48            tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
49        >,
50        tokio_tungstenite::tungstenite::Message,
51    >,
52    /// The underlying tungstenite read half.
53    stream: futures::stream::SplitStream<
54        tokio_tungstenite::WebSocketStream<
55            tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
56        >,
57    >,
58}
59
60/// A WebSocket message received or sent.
61#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct WsMessage {
63    /// Direction of the message.
64    pub direction: WsDirection,
65    /// Message payload (text or JSON string).
66    pub payload: String,
67    /// Timestamp (milliseconds since session start).
68    pub timestamp_ms: u64,
69}
70
71/// Direction of a WebSocket message.
72#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
73pub enum WsDirection {
74    /// Sent from client to server.
75    Sent,
76    /// Received from server.
77    Received,
78}
79
80impl WsSession {
81    /// Connect to a WebSocket endpoint.
82    ///
83    /// Builds the connection URL and optional headers (cookies, auth tokens),
84    /// then opens the WebSocket connection.
85    pub async fn connect(endpoint: &WsEndpoint, cookies: &HashMap<String, String>) -> Result<Self> {
86        use futures::StreamExt;
87        use tokio_tungstenite::tungstenite::http::Request;
88
89        // Build the WebSocket URL.
90        let ws_url = &endpoint.url;
91
92        // Build the HTTP request with cookies/auth headers.
93        let mut request_builder = Request::builder().uri(ws_url);
94
95        // Add cookies as a Cookie header.
96        if !cookies.is_empty() {
97            let cookie_str: String = cookies
98                .iter()
99                .map(|(k, v)| format!("{k}={v}"))
100                .collect::<Vec<_>>()
101                .join("; ");
102            request_builder = request_builder.header("Cookie", cookie_str);
103        }
104
105        // Add origin header (many WS servers require it).
106        let origin = if let Ok(parsed) = url::Url::parse(ws_url) {
107            format!(
108                "{}://{}",
109                if parsed.scheme() == "wss" {
110                    "https"
111                } else {
112                    "http"
113                },
114                parsed.host_str().unwrap_or("localhost")
115            )
116        } else {
117            "https://localhost".to_string()
118        };
119        request_builder = request_builder.header("Origin", &origin);
120
121        let request = request_builder
122            .body(())
123            .map_err(|e| anyhow::anyhow!("failed to build WS request: {e}"))?;
124
125        // Connect.
126        let (ws_stream, _response) = tokio_tungstenite::connect_async(request)
127            .await
128            .map_err(|e| anyhow::anyhow!("WebSocket connection failed: {e}"))?;
129
130        let (sink, stream) = ws_stream.split();
131
132        let domain = url::Url::parse(ws_url)
133            .ok()
134            .and_then(|u| u.host_str().map(|s| s.to_string()))
135            .unwrap_or_default();
136
137        Ok(WsSession {
138            url: ws_url.clone(),
139            protocol: endpoint.protocol.clone(),
140            domain,
141            connected: true,
142            messages: Vec::new(),
143            max_history: 1000,
144            _inner: Mutex::new(Some(WsInner { sink, stream })),
145        })
146    }
147
148    /// Send a JSON-serializable message over the WebSocket.
149    ///
150    /// For Socket.IO, wraps the message in the appropriate frame format.
151    /// For raw WebSocket, sends as-is.
152    pub async fn send_json<T: Serialize>(&mut self, msg: &T) -> Result<()> {
153        use futures::SinkExt;
154        use tokio_tungstenite::tungstenite::Message;
155
156        if !self.connected {
157            bail!("WebSocket is not connected");
158        }
159
160        let payload = serde_json::to_string(msg)?;
161
162        // Wrap for Socket.IO if needed.
163        let wire_payload = match &self.protocol {
164            WsProtocol::SocketIO => format!("42{payload}"),
165            _ => payload.clone(),
166        };
167
168        let mut inner_guard = self._inner.lock().await;
169        if let Some(inner) = inner_guard.as_mut() {
170            inner
171                .sink
172                .send(Message::Text(wire_payload))
173                .await
174                .map_err(|e| anyhow::anyhow!("failed to send WS message: {e}"))?;
175        } else {
176            bail!("WebSocket connection not available");
177        }
178        drop(inner_guard);
179
180        self.messages.push(WsMessage {
181            direction: WsDirection::Sent,
182            payload,
183            timestamp_ms: 0, // Caller can set real timestamps.
184        });
185
186        // Trim history.
187        if self.messages.len() > self.max_history {
188            let drain = self.messages.len() - self.max_history;
189            self.messages.drain(..drain);
190        }
191
192        Ok(())
193    }
194
195    /// Receive the next message from the WebSocket.
196    ///
197    /// Returns `None` if the connection is closed. Automatically skips
198    /// ping/pong control frames and returns the next data message.
199    pub async fn receive(&mut self) -> Result<Option<String>> {
200        use futures::StreamExt;
201        use tokio_tungstenite::tungstenite::Message;
202
203        loop {
204            if !self.connected {
205                return Ok(None);
206            }
207
208            let mut inner_guard = self._inner.lock().await;
209            let inner = match inner_guard.as_mut() {
210                Some(i) => i,
211                None => return Ok(None),
212            };
213
214            match inner.stream.next().await {
215                Some(Ok(Message::Text(text))) => {
216                    // Unwrap Socket.IO frame if needed.
217                    let payload = match &self.protocol {
218                        WsProtocol::SocketIO => text
219                            .strip_prefix("42")
220                            .map(|s| s.to_string())
221                            .unwrap_or(text),
222                        _ => text,
223                    };
224
225                    drop(inner_guard);
226
227                    self.messages.push(WsMessage {
228                        direction: WsDirection::Received,
229                        payload: payload.clone(),
230                        timestamp_ms: 0,
231                    });
232
233                    if self.messages.len() > self.max_history {
234                        let drain = self.messages.len() - self.max_history;
235                        self.messages.drain(..drain);
236                    }
237
238                    return Ok(Some(payload));
239                }
240                Some(Ok(Message::Binary(data))) => {
241                    drop(inner_guard);
242                    return Ok(Some(format!("[binary: {} bytes]", data.len())));
243                }
244                Some(Ok(Message::Close(_))) => {
245                    drop(inner_guard);
246                    self.connected = false;
247                    return Ok(None);
248                }
249                Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => {
250                    // Control frames — skip and loop for next real message.
251                    drop(inner_guard);
252                    continue;
253                }
254                Some(Err(e)) => {
255                    drop(inner_guard);
256                    self.connected = false;
257                    bail!("WebSocket error: {e}");
258                }
259                None => {
260                    drop(inner_guard);
261                    self.connected = false;
262                    return Ok(None);
263                }
264            }
265        }
266    }
267
268    /// Receive messages for a given duration, returning all collected messages.
269    pub async fn watch(&mut self, duration_ms: u64) -> Result<Vec<WsMessage>> {
270        let mut collected = Vec::new();
271        let deadline =
272            tokio::time::Instant::now() + tokio::time::Duration::from_millis(duration_ms);
273
274        loop {
275            let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
276            if remaining.is_zero() {
277                break;
278            }
279
280            match tokio::time::timeout(remaining, self.receive()).await {
281                Ok(Ok(Some(payload))) => {
282                    collected.push(WsMessage {
283                        direction: WsDirection::Received,
284                        payload,
285                        timestamp_ms: 0,
286                    });
287                }
288                Ok(Ok(None)) => break, // Connection closed.
289                Ok(Err(_)) => break,   // Error.
290                Err(_) => break,       // Timeout reached.
291            }
292        }
293
294        Ok(collected)
295    }
296
297    /// Close the WebSocket connection gracefully.
298    pub async fn close(&mut self) -> Result<()> {
299        use futures::SinkExt;
300        use tokio_tungstenite::tungstenite::Message;
301
302        if !self.connected {
303            return Ok(());
304        }
305
306        let mut inner_guard = self._inner.lock().await;
307        if let Some(inner) = inner_guard.as_mut() {
308            inner.sink.send(Message::Close(None)).await.ok();
309        }
310        *inner_guard = None;
311        drop(inner_guard);
312
313        self.connected = false;
314        Ok(())
315    }
316
317    /// Whether the connection is currently open.
318    pub fn is_connected(&self) -> bool {
319        self.connected
320    }
321
322    /// Get the message history.
323    pub fn history(&self) -> &[WsMessage] {
324        &self.messages
325    }
326}
327
#[cfg(test)]
mod tests {
    use super::*;

    /// A message survives a JSON round-trip with all fields intact.
    #[test]
    fn test_ws_message_serde() {
        let original = WsMessage {
            direction: WsDirection::Received,
            payload: String::from(r#"{"type":"update","data":42}"#),
            timestamp_ms: 12_345,
        };

        let encoded = serde_json::to_string(&original).expect("WsMessage should serialize");
        let decoded: WsMessage =
            serde_json::from_str(&encoded).expect("serialized WsMessage should deserialize");

        assert_eq!(decoded.direction, WsDirection::Received);
        assert_eq!(decoded.timestamp_ms, 12_345);
        assert!(decoded.payload.contains("update"));
    }

    /// Equality distinguishes the two directions.
    #[test]
    fn test_ws_direction_eq() {
        let sent = WsDirection::Sent;
        assert_eq!(sent, WsDirection::Sent);
        assert_ne!(sent, WsDirection::Received);
    }
}