cortex_runtime/acquisition/
ws_discovery.rs1use regex::Regex;
14use serde::{Deserialize, Serialize};
15use std::collections::HashSet;
16use std::sync::OnceLock;
17
18const WS_PLATFORMS_JSON: &str = include_str!("ws_platforms.json");
19
20#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
22pub enum WsProtocol {
23 Raw,
25 SocketIO,
27 SockJS,
29 SignalR,
31 Unknown,
33}
34
35#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
37pub enum WsAuth {
38 Cookie,
40 QueryParam,
42 FirstMessage,
44 Header,
46 None,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct WsEndpoint {
53 pub url: String,
55 pub protocol: WsProtocol,
57 pub auth_method: WsAuth,
59 pub discovered_from: String,
61 pub confidence: f32,
63}
64
65#[derive(Debug, Clone, Deserialize)]
68pub(crate) struct WsPlatformConfig {
69 ws_url_pattern: Option<String>,
70 ws_url: Option<String>,
71 protocol: String,
72 auth: String,
73 #[allow(dead_code)]
74 ping: Option<serde_json::Value>,
75 #[allow(dead_code)]
76 auth_message: Option<serde_json::Value>,
77 #[allow(dead_code)]
78 heartbeat: Option<serde_json::Value>,
79 #[allow(dead_code)]
80 send_message: Option<serde_json::Value>,
81}
82
83type WsPlatformRegistry = std::collections::HashMap<String, WsPlatformConfig>;
84
85fn ws_platform_registry() -> &'static WsPlatformRegistry {
86 static REGISTRY: OnceLock<WsPlatformRegistry> = OnceLock::new();
87 REGISTRY.get_or_init(|| serde_json::from_str(WS_PLATFORMS_JSON).unwrap_or_default())
88}
89
90pub fn discover_ws_endpoints(html: &str, js_bundles: &[String], domain: &str) -> Vec<WsEndpoint> {
107 let mut endpoints = Vec::new();
108 let mut seen_urls: HashSet<String> = HashSet::new();
109
110 for (platform_domain, config) in ws_platform_registry() {
112 if domain.contains(platform_domain.as_str()) || platform_domain.contains(domain) {
113 let url = config
114 .ws_url
115 .clone()
116 .or_else(|| config.ws_url_pattern.clone())
117 .unwrap_or_default();
118 if !url.is_empty() && seen_urls.insert(url.clone()) {
119 let protocol = parse_protocol(&config.protocol);
120 let auth_method = parse_auth(&config.auth);
121 endpoints.push(WsEndpoint {
122 url,
123 protocol,
124 auth_method,
125 discovered_from: format!("platform:{platform_domain}"),
126 confidence: 0.95,
127 });
128 }
129 }
130 }
131
132 let sources: Vec<(&str, String)> = std::iter::once((html, "html".to_string()))
134 .chain(
135 js_bundles
136 .iter()
137 .enumerate()
138 .map(|(i, s)| (s.as_str(), format!("js_bundle_{i}"))),
139 )
140 .collect();
141
142 for (source, source_name) in &sources {
143 scan_standard_ws(source, source_name, &mut endpoints, &mut seen_urls);
145
146 scan_socketio(source, source_name, &mut endpoints, &mut seen_urls);
148
149 scan_sockjs(source, source_name, &mut endpoints, &mut seen_urls);
151
152 scan_signalr(source, source_name, &mut endpoints, &mut seen_urls);
154 }
155
156 endpoints
157}
158
159pub fn has_known_ws(domain: &str) -> bool {
161 ws_platform_registry()
162 .keys()
163 .any(|k| domain.contains(k.as_str()) || k.contains(domain))
164}
165
166pub(crate) fn get_known_ws_config(domain: &str) -> Option<&'static WsPlatformConfig> {
168 ws_platform_registry()
169 .iter()
170 .find(|(k, _)| domain.contains(k.as_str()) || k.contains(domain))
171 .map(|(_, v)| v)
172}
173
174fn scan_standard_ws(
177 source: &str,
178 source_name: &str,
179 endpoints: &mut Vec<WsEndpoint>,
180 seen: &mut HashSet<String>,
181) {
182 static RE: OnceLock<Regex> = OnceLock::new();
183 let re = RE.get_or_init(|| {
184 Regex::new(r#"new\s+WebSocket\(\s*['"]((wss?://[^'"]+))['"]"#).expect("ws regex is valid")
185 });
186
187 for caps in re.captures_iter(source) {
188 let url = caps.get(1).map_or("", |m| m.as_str()).to_string();
189 if !url.is_empty() && seen.insert(url.clone()) {
190 endpoints.push(WsEndpoint {
191 url,
192 protocol: WsProtocol::Raw,
193 auth_method: WsAuth::None,
194 discovered_from: source_name.to_string(),
195 confidence: 0.90,
196 });
197 }
198 }
199}
200
201fn scan_socketio(
202 source: &str,
203 source_name: &str,
204 endpoints: &mut Vec<WsEndpoint>,
205 seen: &mut HashSet<String>,
206) {
207 static RE: OnceLock<Regex> = OnceLock::new();
208 let re = RE.get_or_init(|| {
209 Regex::new(r#"io(?:\.connect)?\(\s*['"]((?:wss?|https?)://[^'"]+)['"]"#)
210 .expect("socketio regex is valid")
211 });
212
213 for caps in re.captures_iter(source) {
214 let url = caps.get(1).map_or("", |m| m.as_str()).to_string();
215 if !url.is_empty() && seen.insert(url.clone()) {
216 let ws_url = url
218 .replace("https://", "wss://")
219 .replace("http://", "ws://");
220 endpoints.push(WsEndpoint {
221 url: ws_url,
222 protocol: WsProtocol::SocketIO,
223 auth_method: WsAuth::Cookie,
224 discovered_from: source_name.to_string(),
225 confidence: 0.85,
226 });
227 }
228 }
229}
230
231fn scan_sockjs(
232 source: &str,
233 source_name: &str,
234 endpoints: &mut Vec<WsEndpoint>,
235 seen: &mut HashSet<String>,
236) {
237 static RE: OnceLock<Regex> = OnceLock::new();
238 let re = RE.get_or_init(|| {
239 Regex::new(r#"new\s+SockJS\(\s*['"]([^'"]+)['"]"#).expect("sockjs regex is valid")
240 });
241
242 for caps in re.captures_iter(source) {
243 let url = caps.get(1).map_or("", |m| m.as_str()).to_string();
244 if !url.is_empty() && seen.insert(url.clone()) {
245 endpoints.push(WsEndpoint {
246 url,
247 protocol: WsProtocol::SockJS,
248 auth_method: WsAuth::Cookie,
249 discovered_from: source_name.to_string(),
250 confidence: 0.85,
251 });
252 }
253 }
254}
255
256fn scan_signalr(
257 source: &str,
258 source_name: &str,
259 endpoints: &mut Vec<WsEndpoint>,
260 seen: &mut HashSet<String>,
261) {
262 static RE: OnceLock<Regex> = OnceLock::new();
263 let re = RE.get_or_init(|| {
264 Regex::new(r#"\.withUrl\(\s*['"]([^'"]+)['"]"#).expect("signalr regex is valid")
265 });
266
267 if !source.contains("signalR") && !source.contains("HubConnection") {
269 return;
270 }
271
272 for caps in re.captures_iter(source) {
273 let url = caps.get(1).map_or("", |m| m.as_str()).to_string();
274 if !url.is_empty() && seen.insert(url.clone()) {
275 endpoints.push(WsEndpoint {
276 url,
277 protocol: WsProtocol::SignalR,
278 auth_method: WsAuth::Cookie,
279 discovered_from: source_name.to_string(),
280 confidence: 0.85,
281 });
282 }
283 }
284}
285
286fn parse_protocol(s: &str) -> WsProtocol {
287 match s {
288 "raw" => WsProtocol::Raw,
289 "socketio" | "socket.io" => WsProtocol::SocketIO,
290 "sockjs" => WsProtocol::SockJS,
291 "signalr" => WsProtocol::SignalR,
292 _ => WsProtocol::Unknown,
293 }
294}
295
296fn parse_auth(s: &str) -> WsAuth {
297 match s {
298 "cookie" => WsAuth::Cookie,
299 "query_param" | "query_param_token" => WsAuth::QueryParam,
300 "first_message" | "auth_message" => WsAuth::FirstMessage,
301 "header" => WsAuth::Header,
302 "none" => WsAuth::None,
303 _ => WsAuth::None,
304 }
305}
306
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs discovery over a single JS bundle with no HTML.
    fn discover_in_bundle(js: &str, domain: &str) -> Vec<WsEndpoint> {
        discover_ws_endpoints("", &[js.to_string()], domain)
    }

    #[test]
    fn test_discover_standard_websocket() {
        let page = r#"<script>const ws = new WebSocket("wss://api.example.com/stream");</script>"#;
        let found = discover_ws_endpoints(page, &[], "example.com");

        assert_eq!(found.len(), 1);
        assert_eq!(found[0].url, "wss://api.example.com/stream");
        assert_eq!(found[0].protocol, WsProtocol::Raw);
    }

    #[test]
    fn test_discover_socketio() {
        let found = discover_in_bundle(
            r#"const socket = io.connect("https://realtime.example.com", {transports: ['websocket']});"#,
            "example.com",
        );

        assert_eq!(found.len(), 1);
        assert_eq!(found[0].protocol, WsProtocol::SocketIO);
        assert!(found[0].url.starts_with("wss://"));
    }

    #[test]
    fn test_discover_sockjs() {
        let found = discover_in_bundle(r#"var sock = new SockJS("/ws/notifications");"#, "example.com");

        assert_eq!(found.len(), 1);
        assert_eq!(found[0].protocol, WsProtocol::SockJS);
    }

    #[test]
    fn test_discover_signalr() {
        let bundle = r#"
            const connection = new signalR.HubConnectionBuilder()
                .withUrl("/hubs/chat")
                .build();
        "#;
        let found = discover_in_bundle(bundle, "example.com");

        assert_eq!(found.len(), 1);
        assert_eq!(found[0].protocol, WsProtocol::SignalR);
    }

    #[test]
    fn test_discover_known_platform() {
        let found = discover_ws_endpoints("", &[], "slack.com");

        let first = found
            .first()
            .expect("slack.com should be in the platform registry");
        assert!(first.confidence >= 0.9);
    }

    #[test]
    fn test_empty_html() {
        assert!(discover_ws_endpoints("", &[], "unknown-domain.com").is_empty());
    }

    #[test]
    fn test_deduplication() {
        let page = r#"
        <script>new WebSocket("wss://api.example.com/ws");</script>
        <script>new WebSocket("wss://api.example.com/ws");</script>
        "#;

        assert_eq!(discover_ws_endpoints(page, &[], "example.com").len(), 1);
    }

    #[test]
    fn test_has_known_ws() {
        for known in ["slack.com", "discord.com"] {
            assert!(has_known_ws(known), "{known} should be a known platform");
        }
        assert!(!has_known_ws("random-blog.com"));
    }
}