Skip to main content

cortex_runtime/acquisition/
ws_discovery.rs

1//! Discovers WebSocket endpoints from HTML and JavaScript source code.
2//!
3//! Scans for WebSocket connection patterns in page source and JS bundles:
4//!
5//! 1. **Standard WebSocket** — `new WebSocket("wss://...")`
6//! 2. **Socket.IO** — `io("wss://...",` or `io.connect("...")`
7//! 3. **SockJS** — `new SockJS("...")`
8//! 4. **SignalR** — `new signalR.HubConnectionBuilder().withUrl("...")`
9//!
10//! Also checks known platform configurations for major real-time apps
11//! (Slack, Discord, etc.) via an embedded JSON registry.
12
13use regex::Regex;
14use serde::{Deserialize, Serialize};
15use std::collections::HashSet;
16use std::sync::OnceLock;
17
18const WS_PLATFORMS_JSON: &str = include_str!("ws_platforms.json");
19
20/// The WebSocket protocol/library in use.
21#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
22pub enum WsProtocol {
23    /// Standard WebSocket API.
24    Raw,
25    /// Socket.IO protocol.
26    SocketIO,
27    /// SockJS protocol.
28    SockJS,
29    /// ASP.NET SignalR.
30    SignalR,
31    /// Unknown protocol wrapper.
32    Unknown,
33}
34
35/// How the WebSocket connection authenticates.
36#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
37pub enum WsAuth {
38    /// Auth via cookies (sent automatically).
39    Cookie,
40    /// Auth via token in query parameter.
41    QueryParam,
42    /// Auth via token in the first message.
43    FirstMessage,
44    /// Auth via HTTP header (upgrade request).
45    Header,
46    /// No authentication required.
47    None,
48}
49
50/// A discovered WebSocket endpoint.
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct WsEndpoint {
53    /// The WebSocket URL (wss:// or ws://).
54    pub url: String,
55    /// The protocol/library used.
56    pub protocol: WsProtocol,
57    /// Authentication method.
58    pub auth_method: WsAuth,
59    /// Which source this was discovered from.
60    pub discovered_from: String,
61    /// Confidence that this is a real endpoint, in [0.0, 1.0].
62    pub confidence: f32,
63}
64
65// ── Platform configuration types ────────────────────────────────────────────
66
67#[derive(Debug, Clone, Deserialize)]
68pub(crate) struct WsPlatformConfig {
69    ws_url_pattern: Option<String>,
70    ws_url: Option<String>,
71    protocol: String,
72    auth: String,
73    #[allow(dead_code)]
74    ping: Option<serde_json::Value>,
75    #[allow(dead_code)]
76    auth_message: Option<serde_json::Value>,
77    #[allow(dead_code)]
78    heartbeat: Option<serde_json::Value>,
79    #[allow(dead_code)]
80    send_message: Option<serde_json::Value>,
81}
82
83type WsPlatformRegistry = std::collections::HashMap<String, WsPlatformConfig>;
84
85fn ws_platform_registry() -> &'static WsPlatformRegistry {
86    static REGISTRY: OnceLock<WsPlatformRegistry> = OnceLock::new();
87    REGISTRY.get_or_init(|| serde_json::from_str(WS_PLATFORMS_JSON).unwrap_or_default())
88}
89
90// ── Public API ──────────────────────────────────────────────────────────────
91
92/// Discover WebSocket endpoints from HTML source and JS bundles.
93///
94/// Scans for `new WebSocket(...)`, Socket.IO, SockJS, and SignalR patterns.
95/// Also checks the known platform registry.
96///
97/// # Arguments
98///
99/// * `html` - Raw HTML source of the page.
100/// * `js_bundles` - JavaScript bundle source strings to scan.
101/// * `domain` - The domain being mapped (for platform lookup).
102///
103/// # Returns
104///
105/// A vector of discovered [`WsEndpoint`] items, deduplicated by URL.
106pub fn discover_ws_endpoints(html: &str, js_bundles: &[String], domain: &str) -> Vec<WsEndpoint> {
107    let mut endpoints = Vec::new();
108    let mut seen_urls: HashSet<String> = HashSet::new();
109
110    // Check known platforms first
111    for (platform_domain, config) in ws_platform_registry() {
112        if domain.contains(platform_domain.as_str()) || platform_domain.contains(domain) {
113            let url = config
114                .ws_url
115                .clone()
116                .or_else(|| config.ws_url_pattern.clone())
117                .unwrap_or_default();
118            if !url.is_empty() && seen_urls.insert(url.clone()) {
119                let protocol = parse_protocol(&config.protocol);
120                let auth_method = parse_auth(&config.auth);
121                endpoints.push(WsEndpoint {
122                    url,
123                    protocol,
124                    auth_method,
125                    discovered_from: format!("platform:{platform_domain}"),
126                    confidence: 0.95,
127                });
128            }
129        }
130    }
131
132    // Scan HTML and JS sources
133    let sources: Vec<(&str, String)> = std::iter::once((html, "html".to_string()))
134        .chain(
135            js_bundles
136                .iter()
137                .enumerate()
138                .map(|(i, s)| (s.as_str(), format!("js_bundle_{i}"))),
139        )
140        .collect();
141
142    for (source, source_name) in &sources {
143        // Pattern 1: new WebSocket("wss://..." or "ws://...")
144        scan_standard_ws(source, source_name, &mut endpoints, &mut seen_urls);
145
146        // Pattern 2: Socket.IO
147        scan_socketio(source, source_name, &mut endpoints, &mut seen_urls);
148
149        // Pattern 3: SockJS
150        scan_sockjs(source, source_name, &mut endpoints, &mut seen_urls);
151
152        // Pattern 4: SignalR
153        scan_signalr(source, source_name, &mut endpoints, &mut seen_urls);
154    }
155
156    endpoints
157}
158
159/// Check if a domain has a known WebSocket endpoint.
160pub fn has_known_ws(domain: &str) -> bool {
161    ws_platform_registry()
162        .keys()
163        .any(|k| domain.contains(k.as_str()) || k.contains(domain))
164}
165
166/// Get the known WebSocket configuration for a domain, if any.
167pub(crate) fn get_known_ws_config(domain: &str) -> Option<&'static WsPlatformConfig> {
168    ws_platform_registry()
169        .iter()
170        .find(|(k, _)| domain.contains(k.as_str()) || k.contains(domain))
171        .map(|(_, v)| v)
172}
173
174// ── Private scanning functions ──────────────────────────────────────────────
175
176fn scan_standard_ws(
177    source: &str,
178    source_name: &str,
179    endpoints: &mut Vec<WsEndpoint>,
180    seen: &mut HashSet<String>,
181) {
182    static RE: OnceLock<Regex> = OnceLock::new();
183    let re = RE.get_or_init(|| {
184        Regex::new(r#"new\s+WebSocket\(\s*['"]((wss?://[^'"]+))['"]"#).expect("ws regex is valid")
185    });
186
187    for caps in re.captures_iter(source) {
188        let url = caps.get(1).map_or("", |m| m.as_str()).to_string();
189        if !url.is_empty() && seen.insert(url.clone()) {
190            endpoints.push(WsEndpoint {
191                url,
192                protocol: WsProtocol::Raw,
193                auth_method: WsAuth::None,
194                discovered_from: source_name.to_string(),
195                confidence: 0.90,
196            });
197        }
198    }
199}
200
201fn scan_socketio(
202    source: &str,
203    source_name: &str,
204    endpoints: &mut Vec<WsEndpoint>,
205    seen: &mut HashSet<String>,
206) {
207    static RE: OnceLock<Regex> = OnceLock::new();
208    let re = RE.get_or_init(|| {
209        Regex::new(r#"io(?:\.connect)?\(\s*['"]((?:wss?|https?)://[^'"]+)['"]"#)
210            .expect("socketio regex is valid")
211    });
212
213    for caps in re.captures_iter(source) {
214        let url = caps.get(1).map_or("", |m| m.as_str()).to_string();
215        if !url.is_empty() && seen.insert(url.clone()) {
216            // Convert http(s) to ws(s) for Socket.IO
217            let ws_url = url
218                .replace("https://", "wss://")
219                .replace("http://", "ws://");
220            endpoints.push(WsEndpoint {
221                url: ws_url,
222                protocol: WsProtocol::SocketIO,
223                auth_method: WsAuth::Cookie,
224                discovered_from: source_name.to_string(),
225                confidence: 0.85,
226            });
227        }
228    }
229}
230
231fn scan_sockjs(
232    source: &str,
233    source_name: &str,
234    endpoints: &mut Vec<WsEndpoint>,
235    seen: &mut HashSet<String>,
236) {
237    static RE: OnceLock<Regex> = OnceLock::new();
238    let re = RE.get_or_init(|| {
239        Regex::new(r#"new\s+SockJS\(\s*['"]([^'"]+)['"]"#).expect("sockjs regex is valid")
240    });
241
242    for caps in re.captures_iter(source) {
243        let url = caps.get(1).map_or("", |m| m.as_str()).to_string();
244        if !url.is_empty() && seen.insert(url.clone()) {
245            endpoints.push(WsEndpoint {
246                url,
247                protocol: WsProtocol::SockJS,
248                auth_method: WsAuth::Cookie,
249                discovered_from: source_name.to_string(),
250                confidence: 0.85,
251            });
252        }
253    }
254}
255
256fn scan_signalr(
257    source: &str,
258    source_name: &str,
259    endpoints: &mut Vec<WsEndpoint>,
260    seen: &mut HashSet<String>,
261) {
262    static RE: OnceLock<Regex> = OnceLock::new();
263    let re = RE.get_or_init(|| {
264        Regex::new(r#"\.withUrl\(\s*['"]([^'"]+)['"]"#).expect("signalr regex is valid")
265    });
266
267    // Only match if signalR or HubConnectionBuilder is also present
268    if !source.contains("signalR") && !source.contains("HubConnection") {
269        return;
270    }
271
272    for caps in re.captures_iter(source) {
273        let url = caps.get(1).map_or("", |m| m.as_str()).to_string();
274        if !url.is_empty() && seen.insert(url.clone()) {
275            endpoints.push(WsEndpoint {
276                url,
277                protocol: WsProtocol::SignalR,
278                auth_method: WsAuth::Cookie,
279                discovered_from: source_name.to_string(),
280                confidence: 0.85,
281            });
282        }
283    }
284}
285
286fn parse_protocol(s: &str) -> WsProtocol {
287    match s {
288        "raw" => WsProtocol::Raw,
289        "socketio" | "socket.io" => WsProtocol::SocketIO,
290        "sockjs" => WsProtocol::SockJS,
291        "signalr" => WsProtocol::SignalR,
292        _ => WsProtocol::Unknown,
293    }
294}
295
296fn parse_auth(s: &str) -> WsAuth {
297    match s {
298        "cookie" => WsAuth::Cookie,
299        "query_param" | "query_param_token" => WsAuth::QueryParam,
300        "first_message" | "auth_message" => WsAuth::FirstMessage,
301        "header" => WsAuth::Header,
302        "none" => WsAuth::None,
303        _ => WsAuth::None,
304    }
305}
306
#[cfg(test)]
mod tests {
    use super::*;

    /// A plain `new WebSocket(...)` in inline script is reported as `Raw`.
    #[test]
    fn test_discover_standard_websocket() {
        let page = r#"<script>const ws = new WebSocket("wss://api.example.com/stream");</script>"#;
        let found = discover_ws_endpoints(page, &[], "example.com");
        assert_eq!(found.len(), 1);
        let only = &found[0];
        assert_eq!(only.url, "wss://api.example.com/stream");
        assert_eq!(only.protocol, WsProtocol::Raw);
    }

    /// `io.connect("https://…")` is reported as Socket.IO with a `wss://` URL.
    #[test]
    fn test_discover_socketio() {
        let bundle = r#"const socket = io.connect("https://realtime.example.com", {transports: ['websocket']});"#;
        let bundles = [bundle.to_string()];
        let found = discover_ws_endpoints("", &bundles, "example.com");
        assert_eq!(found.len(), 1);
        let only = &found[0];
        assert_eq!(only.protocol, WsProtocol::SocketIO);
        assert!(only.url.starts_with("wss://"));
    }

    /// A relative SockJS path is still discovered.
    #[test]
    fn test_discover_sockjs() {
        let bundle = r#"var sock = new SockJS("/ws/notifications");"#;
        let bundles = [bundle.to_string()];
        let found = discover_ws_endpoints("", &bundles, "example.com");
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].protocol, WsProtocol::SockJS);
    }

    /// A `.withUrl(...)` next to a `HubConnectionBuilder` is reported as SignalR.
    #[test]
    fn test_discover_signalr() {
        let bundle = r#"
            const connection = new signalR.HubConnectionBuilder()
                .withUrl("/hubs/chat")
                .build();
        "#;
        let bundles = [bundle.to_string()];
        let found = discover_ws_endpoints("", &bundles, "example.com");
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].protocol, WsProtocol::SignalR);
    }

    /// Registry-backed platforms are found without any page source.
    #[test]
    fn test_discover_known_platform() {
        let found = discover_ws_endpoints("", &[], "slack.com");
        assert!(!found.is_empty());
        assert!(found[0].confidence >= 0.9);
    }

    /// Nothing to scan and an unknown domain yields no endpoints.
    #[test]
    fn test_empty_html() {
        let found = discover_ws_endpoints("", &[], "unknown-domain.com");
        assert!(found.is_empty());
    }

    /// The same URL appearing twice is reported once.
    #[test]
    fn test_deduplication() {
        let page = r#"
            <script>new WebSocket("wss://api.example.com/ws");</script>
            <script>new WebSocket("wss://api.example.com/ws");</script>
        "#;
        let found = discover_ws_endpoints(page, &[], "example.com");
        assert_eq!(found.len(), 1); // Deduplicated
    }

    /// Registry membership check for known and unknown domains.
    #[test]
    fn test_has_known_ws() {
        for known in ["slack.com", "discord.com"] {
            assert!(has_known_ws(known));
        }
        assert!(!has_known_ws("random-blog.com"));
    }
}